diff --git a/tutorials/tutorial_pydanticAI/.dockerignore b/tutorials/tutorial_pydanticAI/.dockerignore new file mode 100644 index 000000000..fd85b2584 --- /dev/null +++ b/tutorials/tutorial_pydanticAI/.dockerignore @@ -0,0 +1,143 @@ +# Exclude files from Docker build context. This prevents unnecessary files from +# being sent to Docker daemon, reducing build time and image size. + +# Python artifacts +__pycache__/ +*.pyc +*.pyo +*.pyd +*.egg-info/ + +# Virtual environments +venv/ +.venv/ +env/ +.env +.envrc +client_venv.helpers/ +ENV/ + +# Jupyter +.ipynb_checkpoints/ +.jupyter/ + +# Build artifacts +build/ +dist/ +*.eggs/ +.eggs/ + +# Cache and temporary files +*.log +*.tmp +*.cache +.pytest_cache/ +.mypy_cache/ +.coverage +htmlcov/ + +# Git and version control +.git/ +.gitignore +.gitattributes +.github/ + +# Docker build scripts (not needed at runtime) +docker_build.sh +docker_push.sh +docker_clean.sh +docker_exec.sh +docker_cmd.sh +docker_bash.sh +docker_jupyter.sh +docker_name.sh +run_jupyter.sh +Dockerfile.* +.dockerignore + +# Documentation +README.md +README.admin.md +docs/ +*.md +CHANGELOG.md +LICENSE + +# Configuration and secrets +.env.* +.env.local +.env.development +.env.production +.DS_Store +Thumbs.db + +# Shell configuration +.bashrc +.bash_history +.zshrc + +# Large data files (mount via volume instead) +data/ +*.csv +*.pkl +*.h5 +*.parquet +*.feather +*.arrow +*.npy +*.npz + +# Generated images +*.png +*.jpg +*.jpeg +*.gif +*.svg +*.pdf + +# Test files and examples +tests/ +test_* +*_test.py +tutorials/ +examples/ + +# IDE and editor files +.vscode/ +.idea/ +*.swp +*.swo +*~ +.project +.pydevproject +.settings/ +*.iml +.sublime-project +.sublime-workspace + +# Node and frontend (if applicable) +node_modules/ +npm-debug.log +yarn-error.log +.npm + +# Requirements management +requirements.in +Pipfile +Pipfile.lock +poetry.lock +setup.py +setup.cfg + +# CI/CD configuration +.gitlab-ci.yml +.travis.yml +Jenkinsfile +.circleci/ + +# Miscellaneous +*.bak +.venv.bak/ +*.whl +*.tar.gz +*.zip diff --git a/tutorials/tutorial_pydanticAI/.gitignore b/tutorials/tutorial_pydanticAI/.gitignore new file mode 100644 index 000000000..f69248928 --- /dev/null +++ b/tutorials/tutorial_pydanticAI/.gitignore @@ -0,0 +1,93 @@ +# OS files +.DS_Store +Thumbs.db +desktop.ini + +# Editor / IDE files +.vscode/ +.idea/ +*.swp +*.swo + +# Environment variables +.env +.env.* +!.env.example + +# Python +__pycache__/ +*.py[cod] +*.pyo +*.pyd +.Python +venv/ +env/ +.venv/ +ENV/ +pip-wheel-metadata/ +.pytest_cache/ +.mypy_cache/ +.ruff_cache/ +.coverage +htmlcov/ +dist/ +build/ +*.egg-info/ + +# Jupyter +.ipynb_checkpoints/ + +# Node / JavaScript +node_modules/ +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +package-lock.json +yarn.lock +pnpm-lock.yaml + +# Logs +*.log +logs/ + +# Build outputs +out/ +target/ +bin/ +obj/ + +# Temporary files +tmp/ +temp/ +*.tmp +*.bak +*.old +tmp.pytest.log +tmp.system_output.txt +tmp.system_cmd.sh +.codex + +# Data / local artifacts +data/ +datasets/ +artifacts/ +outputs/ +models/ +checkpoints/ + +# Docker / local config +docker-compose.override.yml + +# ML / experiment tracking +mlruns/ +wandb/ +lightning_logs/ + +# Secrets / credentials +*.pem +*.key +*.crt +credentials.json +token.json +secrets.json \ No newline at end of file diff --git a/tutorials/tutorial_pydanticAI/Dockerfile.python_slim b/tutorials/tutorial_pydanticAI/Dockerfile.python_slim new file mode 100644 index 000000000..cc8f18f2f --- /dev/null +++ b/tutorials/tutorial_pydanticAI/Dockerfile.python_slim @@ -0,0 +1,28 @@ +# Use Python 3.12 slim (already has Python and pip). +FROM python:3.12-slim + +# Avoid interactive prompts during apt operations. +ENV DEBIAN_FRONTEND=noninteractive + +# Install CA certificates (needed for HTTPS). +RUN apt-get update && apt-get install -y \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +# Install project specific packages. +RUN mkdir -p /install +COPY requirements.txt /install/requirements.txt +RUN pip install --upgrade pip && \ + pip install --no-cache-dir jupyterlab jupyterlab_vim jupytext -r /install/requirements.txt + +# Config. +COPY etc_sudoers /install/ +COPY etc_sudoers /etc/sudoers +COPY bashrc /root/.bashrc + +# Report package versions. +COPY version.sh /install/ +RUN /install/version.sh 2>&1 | tee version.log + +# Jupyter. +EXPOSE 8888 diff --git a/tutorials/tutorial_pydanticAI/README.md b/tutorials/tutorial_pydanticAI/README.md index 161343506..b2d402aa2 100644 --- a/tutorials/tutorial_pydanticAI/README.md +++ b/tutorials/tutorial_pydanticAI/README.md @@ -1,70 +1,56 @@ - - -- [Project files](#project-files) -- [Setup and Dependencies](#setup-and-dependencies) - * [Building and Running the Docker Container](#building-and-running-the-docker-container) - + [Environment Setup](#environment-setup) - - - -# Project files - -This project contains the following files. - -- `README.md`: This file -- `pydanticai.API.ipynb`: notebook describing core PydanticAI APIs -- `pydanticai.example.ipynb`: notebook with applied, end-to-end examples -- `requirements.txt`: Python dependencies used by this tutorial -- `example_dataset/`: supporting markdown files used in examples - - `api.md` - - `billing.md` - - `integrations.md` - - `limits.md` - - `overview.md` - - `security.md` - - `support.md` - - `troubleshooting.md` -- Docker/dev runtime files - - `Dockerfile` - - `docker_build.sh` - - `docker_bash.sh` - - `docker_jupyter.sh` - - `docker_exec.sh` - - `docker_cmd.sh` - - `docker_clean.sh` - - `docker_push.sh` - - `docker_name.sh` - - `version.sh` - - `run_jupyter.sh` - - `etc_sudoers` - -# Setup and Dependencies - -## Building and Running the Docker Container - -- Go to the project directory: - ```bash - > cd tutorials/tutorial_pydanticAI - ``` -- Build Docker image: - ```bash - > ./docker_build.sh - ``` -- Run container shell: - ```bash - > ./docker_bash.sh - ``` -- Launch Jupyter Notebook: - ```bash - > ./docker_jupyter.sh - ``` - -### Environment Setup - -Set the `OPENAI_API_KEY` environment variable for API access: - -```python -import os -os.environ["OPENAI_API_KEY"] = "" +# PydanticAI Tutorial + +This folder contains the setup for running PydanticAI tutorials within a +containerized environment. + +## Quick Start + +From the root of the repository, change your directory to the PydanticAI +tutorial folder: + +```bash +> cd tutorials/tutorial_pydanticAI ``` +Once the location has been changed to the repo run the command to build the +image to run dockers: + +```bash +> ./docker_build.sh +``` + +Once the docker has been built you can then go ahead and run the container and +launch jupyter notebook using the created image using the command: + +```bash +> ./docker_jupyter.sh +``` + +Once the `./docker_jupyter.sh` script is running, work through the following +notebooks in order. + +For more information on the Docker build system refer to [Project template +README](/class_project/project_template/README.md) + +## Tutorial Notebooks + +Work through the following notebooks in order: + +- [`pydanticai.API.ipynb`](pydanticai.API.ipynb): Core PydanticAI fundamentals + - Understanding the PydanticAI framework architecture + - Working with PydanticAI classes and methods + - Building basic agent configurations + - Integration with language models + +- [`pydanticai.example.ipynb`](pydanticai.example.ipynb): Real-world application + workflow + - End-to-end agentic application example + - Practical problem-solving with PydanticAI + - Advanced agent interactions and workflows + - Best practices and patterns + +- [`pydanticai_API_utils.py`](pydanticai_API_utils.py): Utility functions + supporting the API tutorial notebook + +- [`pydanticai_example_utils.py`](pydanticai_example_utils.py): Utility + functions supporting the example tutorial notebook diff --git a/tutorials/tutorial_pydanticAI/copy_docker_files.py b/tutorials/tutorial_pydanticAI/copy_docker_files.py new file mode 100644 index 000000000..0e97c194c --- /dev/null +++ b/tutorials/tutorial_pydanticAI/copy_docker_files.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python + +""" +Copy Docker-related files from the source directory to a destination directory. + +This script copies all Docker configuration and utility files from +class_project/project_template/ to a specified destination directory. + +Usage examples: + # Copy all files to a target directory. + > ./copy_docker_files.py --dst_dir /path/to/destination + + # Copy with verbose logging. + > ./copy_docker_files.py --dst_dir /path/to/destination -v DEBUG + +Import as: + +import class_project.project_template.copy_docker_files as cpdccodo +""" + +import argparse +import logging +import os +from typing import List + +import helpers.hdbg as hdbg +import helpers.hio as hio +import helpers.hparser as hparser +import helpers.hsystem as hsystem + +_LOG = logging.getLogger(__name__) + +# ############################################################################# +# Constants +# ############################################################################# + +# List of files to copy from the source directory. +_FILES_TO_COPY = [ + "bashrc", + "docker_bash.sh", + "docker_build.sh", + "docker_clean.sh", + "docker_cmd.sh", + "docker_exec.sh", + "docker_jupyter.sh", + "docker_name.sh", + "docker_push.sh", + "etc_sudoers", + "install_jupyter_extensions.sh", + "run_jupyter.sh" + "version.sh", +] + + +# ############################################################################# +# Helper functions +# ############################################################################# + + +def _get_source_dir() -> str: + """ + Get the absolute path to the source directory containing Docker files. + + :return: absolute path to class_project/project_template/ + """ + # Get the directory where this script is located. + script_dir = os.path.dirname(os.path.abspath(__file__)) + _LOG.debug("Script directory='%s'", script_dir) + return script_dir + + +def _copy_files( + *, + src_dir: str, + dst_dir: str, + files: List[str], +) -> None: + """ + Copy specified files from source directory to destination directory. + + :param src_dir: source directory path + :param dst_dir: destination directory path + :param files: list of filenames to copy + """ + # Verify source directory exists. + hdbg.dassert_dir_exists(src_dir, "Source directory does not exist:", src_dir) + # Create destination directory if it doesn't exist. + hio.create_dir(dst_dir, incremental=True) + _LOG.info("Copying %d files from '%s' to '%s'", len(files), src_dir, dst_dir) + # Copy each file. + copied_count = 0 + for filename in files: + src_path = os.path.join(src_dir, filename) + dst_path = os.path.join(dst_dir, filename) + # Verify source file exists. + hdbg.dassert_path_exists( + src_path, "Source file does not exist:", src_path + ) + # Copy the file using cp -a to preserve all permissions and attributes. + _LOG.debug("Copying '%s' -> '%s'", src_path, dst_path) + cmd = f"cp -a {src_path} {dst_path}" + hsystem.system(cmd) + copied_count += 1 + # + _LOG.info("Successfully copied %d files", copied_count) + + +# ############################################################################# + + +def _parse() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--dst_dir", + action="store", + required=True, + help="Destination directory where files will be copied", + ) + hparser.add_verbosity_arg(parser) + return parser + + +def _main(parser: argparse.ArgumentParser) -> None: + args = parser.parse_args() + hdbg.init_logger(verbosity=args.log_level, use_exec_path=True) + # Get source directory. + src_dir = _get_source_dir() + # Copy files to destination. + _copy_files( + src_dir=src_dir, + dst_dir=args.dst_dir, + files=_FILES_TO_COPY, + ) + + +if __name__ == "__main__": + _main(_parse()) diff --git a/tutorials/tutorial_pydanticAI/docker_build.version.log b/tutorials/tutorial_pydanticAI/docker_build.version.log deleted file mode 100644 index ac7068564..000000000 --- a/tutorials/tutorial_pydanticAI/docker_build.version.log +++ /dev/null @@ -1,180 +0,0 @@ -+ cd ../tmp.build -++ pwd -+ docker run --rm -it -v /home/aver23/src/aver81/umd_classes1/tutorials/tmp.build:/data gpsaggese/umd_pydanticai_tutorial bash -c /data/version.sh -# Python3 -Python 3.12.3 -# pip3 -pip 26.0.1 from /venv/lib/python3.12/site-packages/pip (python 3.12) -# jupyter -Selected Jupyter core packages... -IPython : 9.10.0 -ipykernel : 6.30.1 -ipywidgets : not installed -jupyter_client : 8.8.0 -jupyter_core : 5.9.1 -jupyter_server : 2.17.0 -jupyterlab : 4.4.6 -nbclient : 0.10.2 -nbconvert : 7.17.0 -nbformat : 5.10.4 -notebook : not installed -qtconsole : not installed -traitlets : 5.14.3 -# Python packages -Package Version -------------------------- ----------- -aiohappyeyeballs 2.6.1 -aiohttp 3.13.3 -aiosignal 1.4.0 -annotated-types 0.7.0 -anthropic 0.84.0 -anyio 4.12.1 -argon2-cffi 25.1.0 -argon2-cffi-bindings 25.1.0 -arrow 1.4.0 -asttokens 3.0.1 -async-lru 2.2.0 -attrs 25.4.0 -babel 2.18.0 -beautifulsoup4 4.14.3 -bleach 6.3.0 -bracex 2.6 -certifi 2026.2.25 -cffi 2.0.0 -charset-normalizer 3.4.4 -click 8.3.1 -comm 0.2.3 -contourpy 1.3.3 -cryptography 46.0.5 -cycler 0.12.1 -debugpy 1.8.20 -decorator 5.2.1 -deepagents 0.3.11 -defusedxml 0.7.1 -distro 1.9.0 -docstring_parser 0.17.0 -entrypoints 0.4 -executing 2.2.1 -fastjsonschema 2.21.2 -filetype 1.2.0 -fonttools 4.61.1 -fqdn 1.5.1 -frozenlist 1.8.0 -google-auth 2.48.0 -google-genai 1.65.0 -h11 0.16.0 -httpcore 1.0.9 -httpx 0.28.1 -idna 3.11 -ipykernel 6.30.1 -ipython 9.10.0 -ipython_pygments_lexers 1.1.1 -isoduration 20.11.0 -jedi 0.19.2 -Jinja2 3.1.6 -jiter 0.13.0 -json5 0.13.0 -jsonpatch 1.33 -jsonpointer 3.0.0 -jsonschema 4.26.0 -jsonschema-specifications 2025.9.1 -jupyter_client 8.8.0 -jupyter_core 5.9.1 -jupyter-events 0.12.0 -jupyter-lsp 2.3.0 -jupyter_server 2.17.0 -jupyter_server_terminals 0.5.4 -jupyterlab 4.4.6 -jupyterlab_pygments 0.3.0 -jupyterlab_server 2.28.0 -kiwisolver 1.4.9 -langchain 1.2.8 -langchain-anthropic 1.3.1 -langchain-core 1.2.8 -langchain-google-genai 4.2.1 -langchain-openai 1.1.7 -langgraph 1.0.7 -langgraph-checkpoint 4.0.1 -langgraph-prebuilt 1.0.8 -langgraph-sdk 0.3.9 -langsmith 0.7.9 -lark 1.3.1 -MarkupSafe 3.0.3 -matplotlib 3.10.5 -matplotlib-inline 0.2.1 -mistune 3.2.0 -multidict 6.7.1 -nbclient 0.10.2 -nbconvert 7.17.0 -nbformat 5.10.4 -nest-asyncio 1.6.0 -notebook_shim 0.2.4 -numpy 2.3.2 -openai 2.24.0 -orjson 3.11.7 -ormsgpack 1.12.2 -packaging 26.0 -pandas 2.3.2 -pandocfilters 1.5.1 -papermill 2.7.0 -parso 0.8.6 -pexpect 4.9.0 -pillow 12.1.1 -pip 26.0.1 -platformdirs 4.9.2 -prometheus_client 0.24.1 -prompt_toolkit 3.0.52 -propcache 0.4.1 -psutil 7.2.2 -ptyprocess 0.7.0 -pure_eval 0.2.3 -pyasn1 0.6.2 -pyasn1_modules 0.4.2 -pycparser 3.0 -pydantic 2.11.7 -pydantic_core 2.33.2 -Pygments 2.19.2 -pyparsing 3.3.2 -python-dateutil 2.9.0.post0 -python-dotenv 1.1.1 -python-json-logger 4.0.0 -pytz 2025.2 -PyYAML 6.0.3 -pyzmq 27.1.0 -referencing 0.37.0 -regex 2026.2.28 -requests 2.32.5 -requests-toolbelt 1.0.0 -rfc3339-validator 0.1.4 -rfc3986-validator 0.1.1 -rfc3987-syntax 1.1.0 -rpds-py 0.30.0 -rsa 4.9.1 -Send2Trash 2.1.0 -setuptools 82.0.0 -six 1.17.0 -sniffio 1.3.1 -soupsieve 2.8.3 -stack-data 0.6.3 -tenacity 9.1.4 -terminado 0.18.1 -tiktoken 0.12.0 -tinycss2 1.4.0 -tornado 6.5.4 -tqdm 4.67.3 -traitlets 5.14.3 -typing_extensions 4.14.1 -typing-inspection 0.4.2 -tzdata 2025.3 -uri-template 1.3.0 -urllib3 2.6.3 -uuid_utils 0.14.1 -wcmatch 10.1 -wcwidth 0.6.0 -webcolors 25.10.0 -webencodings 0.5.1 -websocket-client 1.9.0 -websockets 16.0 -xxhash 3.6.0 -yarl 1.22.0 -zstandard 0.25.0 diff --git a/tutorials/tutorial_pydanticAI/pydanticai.API.ipynb b/tutorials/tutorial_pydanticAI/pydanticai.API.ipynb index bd51d93b9..5fb65816b 100644 --- a/tutorials/tutorial_pydanticAI/pydanticai.API.ipynb +++ b/tutorials/tutorial_pydanticAI/pydanticai.API.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "72ad5d56", "metadata": {}, "outputs": [], @@ -10,206 +10,236 @@ "%load_ext autoreload\n", "%autoreload 2\n", "\n", + "# System libraries.\n", "import logging\n", "\n", - "\n", - "import helpers.hnotebook as ut\n", - "\n", - "ut.config_notebook()\n", - "\n", - "# Initialize logger.\n", - "logging.basicConfig(level=logging.INFO)\n", - "_LOG = logging.getLogger(__name__)" + "# Third party libraries.\n", + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "a066f6ee", "metadata": {}, "outputs": [], "source": [ - "import pydanticai_API_utils as utils" + "# System libraries.\n", + "import asyncio\n", + "import os\n", + "\n", + "# Third party libraries.\n", + "from dataclasses import dataclass\n", + "\n", + "import nest_asyncio\n", + "from dotenv import find_dotenv, load_dotenv\n", + "from pydantic import BaseModel\n", + "from pydantic_ai import Agent\n", + "from pydantic_ai import ModelRetry\n", + "\n", + "# Local utilities.\n", + "import pydanticai_API_utils as utils\n", + "\n", + "# Notebook-specific imports are ready for tutorial examples." ] }, { - "cell_type": "markdown", - "id": "784a674e", + "cell_type": "code", + "execution_count": 3, + "id": "ef3d968c", "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[0mWARNING: Running in Jupyter\n", + "INFO > cmd='/opt/venv/lib/python3.12/site-packages/ipykernel_launcher.py -f /root/.local/share/jupyter/runtime/kernel-9b586c65-e2dd-4a36-8b3c-1ffb500bfed5.json'\n" + ] + } + ], "source": [ - "## PydanticAI API Tutorial Introduction\n", - "\n", - "PydanticAI is a lightweight framework for building LLM-powered applications with **structured outputs using Pydantic models**.\n", - "\n", - "Unlike traditional LLM APIs that return unstructured text, PydanticAI ensures responses conform to a predefined schema.\n", - "\n", - "This notebook covers:\n", + "# Configure notebook logging.\n", + "import logging\n", "\n", - "- Core concepts\n", - "- Agent API\n", - "- Structured outputs\n", - "- Tool usage\n", - "- Validation and retries\n", - "- Async execution\n", + "# Local utility.\n", + "import pydanticai_API_utils as utils\n", "\n", - "By the end, you will understand how to build reliable LLM pipelines using structured outputs." + "_LOG = logging.getLogger(__name__)\n", + "utils.init_logger(_LOG)\n", + "_LOG\n", + "# Notebook logging is configured for the tutorial cells." ] }, { "cell_type": "markdown", - "id": "e6f09140", + "id": "8e2b9ddc", "metadata": {}, "source": [ - "# Table of Contents\n", + "# Summary\n", "\n", - "1. Introduction\n", - "2. Why PydanticAI exists\n", - "3. Installation\n", - "4. Minimal Example\n", - "5. Core Concepts\n", - "6. Structured Outputs\n", - "7. Validation\n", - "8. Tools\n", - "9. Dependencies\n", - "10. Async Execution\n", - "11. Advanced Features\n", - "12. Best Practices\n", - "13. Summary" + "- This notebook introduces `PydanticAI` APIs for building LLM workflows, including structured outputs, tools, dependencies, validators, streaming, provider configuration, run metadata, and usage limits" ] }, { "cell_type": "markdown", - "id": "1d6d6c57", + "id": "784a674e", "metadata": {}, "source": [ - "### Why PydanticAI Exists\n", + "# PydanticAI API Introduction\n", "\n", - "LLMs typically return unstructured text.\n", - "\n", - "Example:\n", - "\n", - "User prompt:\n", - "\"Extract product information from this description\"\n", - "\n", - "LLM output:\n", - "\"The product is an iPhone 15 priced at $999.\"\n", - "\n", - "This output is difficult to use programmatically.\n", - "\n", - "What we want instead:\n", - "\n", - "{\n", - " \"product_name\": \"iPhone 15\",\n", - " \"price\": 999\n", - "}\n", - "\n", - "PydanticAI solves this problem by:\n", - "\n", - "- Defining schemas using **Pydantic models**\n", - "- Enforcing structured outputs\n", - "- Automatically retrying when validation fails\n", - "- Providing a simple agent abstraction for LLM interaction" + "- `PydanticAI` is a lightweight framework for building LLM-powered applications with structured outputs\n", + "- `PydanticAI` uses `Pydantic` models to define response schemas\n", + "- Traditional LLM APIs often return unstructured text\n", + "- `PydanticAI` keeps responses aligned with a predefined schema" ] }, { "cell_type": "markdown", - "id": "9f09d9ed", + "id": "1d6d6c57", "metadata": {}, "source": [ - "### Mental Model\n", - "\n", - "```\n", - "User Prompt\n", - " v\n", - "PydanticAI Agent\n", - " v\n", - "LLM\n", - " v\n", - "Raw Response\n", - " v\n", - "Pydantic Validation\n", - " v\n", - "Structured Output\n", - "```" + "## Why PydanticAI Exists\n", + "\n", + "- Key problem: LLMs typically return unstructured text\n", + "- Example prompt:\n", + " - \"Extract product information from this description\"\n", + "- Example LLM output:\n", + " - \"The product is an iPhone 15 priced at $999.\"\n", + "- Problem with the example LLM output:\n", + " - The example LLM output is difficult to use programmatically\n", + "- Desired structured output:\n", + "\n", + " ```json\n", + " {\n", + " \"product_name\": \"iPhone 15\",\n", + " \"price\": 999\n", + " }\n", + " ```\n", + "\n", + "- `PydanticAI` solves this problem with:\n", + " - Schema definitions with `Pydantic` models\n", + " - Structured output enforcement\n", + " - Automatic retries after validation failures\n", + " - A simple agent abstraction for LLM interaction" ] }, { "cell_type": "markdown", - "id": "64b752ff", + "id": "9f09d9ed", "metadata": {}, "source": [ - "## Installation\n", - "\n", - "We install a minimal set of packages to keep the notebook self-contained and reproducible.\n", - "\n", - "This notebook uses `pydantic-ai`, `pydantic`, and `python-dotenv`.\n" + "## Mental Model\n", + "\n", + "- `PydanticAI` flow:\n", + " ```mermaid\n", + " flowchart TD\n", + " A[User Prompt] --> B[PydanticAI Agent]\n", + " B --> C[LLM]\n", + " C --> D[Raw Response]\n", + " D --> E[Pydantic Validation]\n", + " E --> F[Structured Output]\n", + " ```" ] }, { "cell_type": "code", - "execution_count": 5, - "id": "bba9f441", + "execution_count": 4, + "id": "eff13ce6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" + "dotenv path: /git_root/tutorials/tutorial_pydanticAI/.env\n" ] + }, + { + "data": { + "text/plain": [ + "'/git_root/tutorials/tutorial_pydanticAI/.env'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "!pip install -q pydantic-ai" + "# Load environment variables from a local dotenv file if one exists.\n", + "env_path = find_dotenv(usecwd=True)\n", + "load_dotenv(env_path, override=True)\n", + "_LOG.info(\"dotenv path: %s\", env_path or \"\")\n", + "env_path or \"\"\n", + "# Environment variables are available to the model configuration cells." ] }, { "cell_type": "code", - "execution_count": 6, - "id": "eff13ce6", + "execution_count": 5, + "id": "9bb08251", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "dotenv path: /curr_dir/.env\n", - "PYDANTIC_AI_MODEL: openai:gpt-5-2025-08-07\n", + "dotenv path: /git_root/tutorials/tutorial_pydanticAI/.env\n", + "PYDANTIC_AI_MODEL: openai:gpt-4.1-mini\n", "OPENAI_API_KEY: sk-...8A\n" ] + }, + { + "data": { + "text/plain": [ + "{'model_id': 'openai:gpt-4.1-mini'}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "import os\n", - "from dotenv import load_dotenv, find_dotenv\n", - "import nest_asyncio\n", - "\n", - "nest_asyncio.apply()\n", - "\n", - "\n", - "env_path = find_dotenv(usecwd=True)\n", - "load_dotenv(env_path, override=True)\n", - "\n", - "MODEL_ID = os.getenv(\"PYDANTIC_AI_MODEL\", \"openai:gpt-4.1-mini\")\n", - "print(\"dotenv path:\", env_path or \"\")\n", - "print(\"PYDANTIC_AI_MODEL:\", MODEL_ID)\n", - "print(\"OPENAI_API_KEY:\", utils._mask(os.getenv(\"OPENAI_API_KEY\")))" + "# Read the model identifier from the environment.\n", + "MODEL_ID = os.getenv(\"PYDANTIC_AI_MODEL\")\n", + "utils.log_environment(env_path, MODEL_ID)\n", + "{\"model_id\": MODEL_ID}\n", + "# The tutorial examples will use the configured model identifier." ] }, { "cell_type": "markdown", - "id": "f2e6d162", + "id": "15a7a6e3", "metadata": {}, "source": [ - "### Running the Notebook\n", + "# Core Concepts\n", + "\n", + "- `PydanticAI` revolves around a few important abstractions\n", + "\n", + "## Agent\n", "\n", - "To run the examples you must set your API key.\n", + "- `Agent` is the main interface for interacting with the model\n", + "- `Agent` manages:\n", + " - LLM calls\n", + " - Structured outputs\n", + " - Retries\n", + " - Tool usage\n", + "\n", + "## output_type\n", + "\n", + "- `output_type` defines the expected structured output\n", + "- `output_type` must be a `Pydantic` model\n", + "\n", + "## Tools\n", "\n", - "Example:\n", - "```\n", - "export OPENAI_API_KEY=\"your_key_here\"\n", - "```" + "- Tools are functions that the agent can call during reasoning\n", + "- Tools let agents interact with external systems such as APIs or databases\n", + "\n" ] }, { @@ -217,23 +247,24 @@ "id": "8569d597", "metadata": {}, "source": [ - "## Minimal Example\n", - "\n", - "The quickest way to understand PydanticAI is through a small example.\n", + "# Minimal Example\n", "\n", - "We define a schema using Pydantic and instruct the agent to produce that structured output." + "- The quickest way to understand `PydanticAI` is a small example\n", + "- This section defines a schema with `Pydantic` and asks the agent to produce that structured output" ] }, { "cell_type": "code", "execution_count": 6, - "id": "b7e487b4", - "metadata": {}, + "id": "5d68a76d", + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [ { "data": { "text/plain": [ - "City(name='Paris', country='France', population=2140526)" + "__main__.City" ] }, "execution_count": 6, @@ -242,148 +273,149 @@ } ], "source": [ - "from pydantic import BaseModel\n", - "from pydantic_ai import Agent\n", - "\n", - "\n", + "# Define the output schema for the minimal example.\n", "class City(BaseModel):\n", " name: str\n", " country: str\n", " population: int\n", "\n", "\n", - "agent = Agent(\"openai:gpt-4o-mini\", output_type=City)\n", - "\n", - "result = agent.run_sync(\"Tell me about Paris\")\n", - "\n", - "result.output" + "City\n", + "# The schema defines the exact output shape expected from the model." ] }, { - "cell_type": "markdown", - "id": "86efa23c", + "cell_type": "code", + "execution_count": 7, + "id": "b7e487b4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Agent(model=OpenAIChatModel(), name=None, end_strategy='early', model_settings=None, output_type=, instrument=None)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create an agent that must return `City`.\n", + "agent = Agent(MODEL_ID, output_type=City)\n", + "agent\n", + "# The agent is configured to validate model output against class `City`." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "5545ded5", "metadata": {}, + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "This event loop is already running", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m# Run the minimal example agent.\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m result = agent.run_sync(\u001b[33m\"Tell me about Paris\"\u001b[39m)\n\u001b[32m 3\u001b[39m \n\u001b[32m 4\u001b[39m result.output\n\u001b[32m 5\u001b[39m \u001b[38;5;66;03m# The result is a validated `City` object.\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/opt/venv/lib/python3.12/site-packages/pydantic_ai/agent/abstract.py:509\u001b[39m, in \u001b[36mAbstractAgent.run_sync\u001b[39m\u001b[34m(self, user_prompt, output_type, message_history, deferred_tool_results, model, instructions, deps, model_settings, usage_limits, usage, metadata, infer_name, toolsets, builtin_tools, event_stream_handler, capabilities, spec)\u001b[39m\n\u001b[32m 506\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m infer_name \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m.name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 507\u001b[39m \u001b[38;5;28mself\u001b[39m._infer_name(inspect.currentframe())\n\u001b[32m--> \u001b[39m\u001b[32m509\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[30;43m_utils\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mget_event_loop\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43m)\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mrun_until_complete\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 510\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mrun\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 511\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43muser_prompt\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 512\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43moutput_type\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43moutput_type\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 513\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mmessage_history\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mmessage_history\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 514\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mdeferred_tool_results\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mdeferred_tool_results\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 515\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mmodel\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mmodel\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 516\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43minstructions\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43minstructions\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 517\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mdeps\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mdeps\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 518\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mmodel_settings\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mmodel_settings\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 519\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43musage_limits\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43musage_limits\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 520\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43musage\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43musage\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 521\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mmetadata\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mmetadata\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 522\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43minfer_name\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43;01mFalse\u001b[39;49;00m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 523\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mtoolsets\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mtoolsets\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 524\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mbuiltin_tools\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mbuiltin_tools\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 525\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mevent_stream_handler\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mevent_stream_handler\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 526\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mcapabilities\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mcapabilities\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 527\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mspec\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mspec\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 528\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 529\u001b[39m \u001b[30;43m\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.12/asyncio/base_events.py:667\u001b[39m, in \u001b[36mBaseEventLoop.run_until_complete\u001b[39m\u001b[34m(self, future)\u001b[39m\n\u001b[32m 656\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33;03m\"\"\"Run until the Future is done.\u001b[39;00m\n\u001b[32m 657\u001b[39m \n\u001b[32m 658\u001b[39m \u001b[33;03mIf the argument is a coroutine, it is wrapped in a Task.\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 664\u001b[39m \u001b[33;03mReturn the Future's result, or raise its exception.\u001b[39;00m\n\u001b[32m 665\u001b[39m \u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 666\u001b[39m \u001b[38;5;28mself\u001b[39m._check_closed()\n\u001b[32m--> \u001b[39m\u001b[32m667\u001b[39m \u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43m_check_running\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 669\u001b[39m new_task = \u001b[38;5;129;01mnot\u001b[39;00m futures.isfuture(future)\n\u001b[32m 670\u001b[39m future = tasks.ensure_future(future, loop=\u001b[38;5;28mself\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m/usr/local/lib/python3.12/asyncio/base_events.py:626\u001b[39m, in \u001b[36mBaseEventLoop._check_running\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 624\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_check_running\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[32m 625\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.is_running():\n\u001b[32m--> \u001b[39m\u001b[32m626\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[33m'\u001b[39m\u001b[33mThis event loop is already running\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m 627\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m events._get_running_loop() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 628\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[32m 629\u001b[39m \u001b[33m'\u001b[39m\u001b[33mCannot run the event loop while another loop is running\u001b[39m\u001b[33m'\u001b[39m)\n", + "\u001b[31mRuntimeError\u001b[39m: This event loop is already running" + ] + } + ], "source": [ - "### What Happened?\n", + "# Run the minimal example agent.\n", + "result = agent.run_sync(\"Tell me about Paris\")\n", "\n", - "1. A Pydantic schema (`City`) defines the expected output structure.\n", - "2. The `Agent` sends the prompt to the LLM.\n", - "3. The LLM response is validated against the schema.\n", - "4. If validation succeeds, the structured result is returned." + "result.output\n", + "# The result is a validated `City` object." ] }, { "cell_type": "markdown", - "id": "15a7a6e3", + "id": "ba8f4833-dacb-435c-8bc8-1daeb718262e", "metadata": {}, "source": [ - "## Core Concepts\n", - "\n", - "PydanticAI revolves around a few important abstractions.\n", - "\n", - "### Agent\n", - "\n", - "The `Agent` is the main interface for interacting with the model.\n", - "\n", - "It manages:\n", - "\n", - "- LLM calls\n", - "- structured outputs\n", - "- retries\n", - "- tool usage\n", + "# Resolving the Above RuntimeError in Jupyter\n", "\n", - "### output_type\n", - "\n", - "Defines the expected structured output.\n", - "\n", - "This must be a Pydantic model.\n", - "\n", - "### Tools\n", - "\n", - "Functions that the agent can call during reasoning.\n", - "\n", - "Tools allow agents to interact with external systems such as APIs or databases.\n", - "\n" + "- Key thing to remember: Jupyter already runs an active event loop" ] }, { "cell_type": "markdown", - "id": "e0f3aa76", + "id": "ce72edf2-d1f4-4d60-ac36-29680d884d9a", "metadata": {}, "source": [ - "## Structured Outputs with Pydantic" + "- `agent.run_sync()` can raise a `RuntimeError` in notebook environments\n", + "- `nest_asyncio` patches the notebook event loop so nested async execution can work\n", + "- After `nest_asyncio.apply()`, async `PydanticAI` examples can run inside notebook cells" ] }, { "cell_type": "code", - "execution_count": 7, - "id": "895da5b2", - "metadata": { - "lines_to_next_cell": 2 - }, + "execution_count": 9, + "id": "bba9f441", + "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Nested event loop support enabled.\n" + ] + }, { "data": { "text/plain": [ - "Product(name='Apple AirPods Pro', price=249.0, category='Electronics')" + "True" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from pydantic import BaseModel\n", - "\n", - "\n", - "class Product(BaseModel):\n", - " name: str\n", - " price: float\n", - " category: str\n", - "\n", - "\n", - "agent = Agent(\"openai:gpt-4o-mini\", output_type=Product)\n", - "\n", - "agent.run_sync(\"Describe the Apple AirPods Pro\").output" + "# Enable nested event loops for notebook execution.\n", + "nest_asyncio.apply()\n", + "nested_event_loop_enabled = True\n", + "_LOG.info(\"Nested event loop support enabled.\")\n", + "nested_event_loop_enabled\n", + "# Async PydanticAI examples can now run from notebook cells." ] }, { "cell_type": "markdown", - "id": "d8d15d06-6d82-42cf-b003-7b85cf45eb2d", + "id": "46db7bd2-16ae-46ec-8b03-361b80a9aa40", "metadata": {}, "source": [ - "### What happened in the code\n", - "\n", - "- We defined a `Product` schema (name, price, category).\n", - "- The agent is configured to produce outputs that conform to this schema.\n", - "- When the model answers, PydanticAI validates that:\n", - " - `price` is a number\n", - " - fields exist with the right types\n", - " - the structure matches exactly\n", - "\n", - "**Why PydanticAI is useful here:**\n", - "This turns LLM responses into structured data you can store in databases, feed into analytics, or pass downstream in an application without brittle string parsing." + "- Re-run the previous cell that raised the `RuntimeError`" ] }, { "cell_type": "markdown", - "id": "5716df9d", - "metadata": { - "lines_to_next_cell": 2 - }, + "id": "e0f3aa76", + "metadata": {}, "source": [ - "## Validation and Retries\n", + "# Structured Outputs with Pydantic\n", "\n", - "If the LLM produces an output that does not match the schema, PydanticAI automatically retries.\n", - "\n", - "This greatly improves reliability." + "- `PydanticAI` turns LLM responses into structured data\n", + "- Structured outputs help you:\n", + " - Store validated outputs in databases\n", + " - Feed typed objects into analytics\n", + " - Pass structured data downstream without brittle string parsing" ] }, { "cell_type": "code", - "execution_count": 8, - "id": "775f32dd", + "execution_count": 10, + "id": "636df5ab", "metadata": { "lines_to_next_cell": 2 }, @@ -391,191 +423,318 @@ { "data": { "text/plain": [ - "AgentRunResult(output=Person(name='Albert Einstein', age=76))" + "__main__.Product" ] }, - "execution_count": 8, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "class Person(BaseModel):\n", + "# Define a product schema for structured extraction.\n", + "class Product(BaseModel):\n", " name: str\n", - " age: int\n", - "\n", - "\n", - "agent = Agent(\"openai:gpt-4o-mini\", output_type=Person, retries=2)\n", - "\n", - "agent.run_sync(\"Tell me about Albert Einstein\")" - ] - }, - { - "cell_type": "markdown", - "id": "c55fb759-4d81-4a4d-8899-9759b8d82f27", - "metadata": {}, - "source": [ - "### What happened in the code\n", + " price: float\n", + " category: str\n", "\n", - "- We defined a `Person` schema with `name` and `age`.\n", - "- We set `retries=2` on the agent.\n", - "- If the model output fails schema validation (missing fields, wrong types), PydanticAI automatically retries the model call to get a valid output.\n", "\n", - "**Why PydanticAI is useful here:**\n", - "Real LLM outputs are inconsistent. Automatic schema validation + retry gives you reliability without writing custom parsing and retry logic for every prompt." + "Product\n", + "# The schema captures the product fields we want to extract." ] }, { - "cell_type": "markdown", - "id": "3948bb6c", + "cell_type": "code", + "execution_count": 11, + "id": "895da5b2", "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Agent(model=OpenAIChatModel(), name=None, end_strategy='early', model_settings=None, output_type=, instrument=None)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "## Tools\n", - "\n", - "Agents can call Python functions as tools." + "# Create an agent that must return `Product`.\n", + "agent = Agent(MODEL_ID, output_type=Product)\n", + "agent\n", + "# The agent is configured to return product data with typed fields." ] }, { "cell_type": "code", - "execution_count": 9, - "id": "099d9d99", + "execution_count": 12, + "id": "9b141b60", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "AgentRunResult(output='The weather in Tokyo is sunny.')" + "Product(name='Apple AirPods Pro', price=249.0, category='Electronics/Audio')" ] }, - "execution_count": 9, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "agent = Agent(\"openai:gpt-4o-mini\", tools=[utils.get_weather])\n", - "\n", - "agent.run_sync(\"What is the weather in Tokyo?\")" + "# Ask the model for structured product information.\n", + "agent.run_sync(\"Describe the Apple AirPods Pro\").output\n", + "# The response is validated as a `Product` class object." ] }, { "cell_type": "markdown", - "id": "57381ed0-cf9f-467c-8437-d3858c7b29a7", + "id": "5716df9d", "metadata": {}, "source": [ - "### What happened in the code\n", - "\n", - "- We defined a Python function `get_weather(city)` that returns a deterministic string.\n", - "- We passed it into the agent via `tools=[get_weather]`.\n", - "- When the user asks about weather, the agent can choose to call the tool to get the answer instead of hallucinating.\n", + "# Validation and Retries\n", "\n", - "**Why PydanticAI is useful here:**\n", - "Tools let the model interact with real functions and external systems. This is how you build agents that do real work (APIs, databases, calculations) rather than confidently inventing facts." + "- Real LLM outputs are inconsistent\n", + "- Schema validation checks the generated structure\n", + "- Retries let `PydanticAI` ask the model to repair invalid output\n", + "- This notebook avoids custom parsing and retry logic in each prompt" ] }, { - "cell_type": "markdown", - "id": "6bbc710d", - "metadata": {}, + "cell_type": "code", + "execution_count": 13, + "id": "4b256f36", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "__main__.Person" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "## Dependencies\n", + "# Define a schema that requires an integer age.\n", + "class Person(BaseModel):\n", + " name: str\n", + " age: int\n", + "\n", "\n", - "Dependencies allow agents to access external resources or shared state." + "Person\n", + "# The schema enforces integer typing for age values." ] }, { "cell_type": "code", - "execution_count": 10, - "id": "772c04ee", + "execution_count": 14, + "id": "775f32dd", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "The configured company is OpenAI.\n" - ] - } - ], + "data": { + "text/plain": [ + "Agent(model=OpenAIChatModel(), name=None, end_strategy='early', model_settings=None, output_type=, instrument=None)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "from dataclasses import dataclass\n", - "from pydantic_ai import Agent\n", - "\n", - "\n", - "@dataclass\n", - "class Config:\n", - " company: str\n", - "\n", - "\n", - "agent = Agent(\"openai:gpt-4o-mini\", deps_type=Config, tools=[utils.company_name])\n", - "\n", - "result = agent.run_sync(\n", - " \"What company is configured?\", deps=Config(company=\"OpenAI\")\n", - ")\n", - "print(result.output)" + "# Configure retries so schema validation failures can be corrected.\n", + "agent = Agent(MODEL_ID, output_type=Person, retries=2)\n", + "agent\n", + "# The agent can retry when model output does not match `Person`." ] }, { - "cell_type": "markdown", - "id": "9c263739-f6e0-4cb7-ae54-15b9f6e87a9d", + "cell_type": "code", + "execution_count": 15, + "id": "5d8126e2", "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AgentRunResult(output=Person(name='Albert Einstein', age=76))" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "### What happened in the code\n", - "\n", - "- `deps_type=Config` declares the *shape* of runtime context the agent can receive.\n", - "- At run time, we pass an instance like `Config(company=\"OpenAI\")`.\n", - "- Tools (or other agent logic) can access this via `RunContext.deps`, so the agent can use configuration/state without hardcoding it into prompts.\n", - "\n", - "**Why PydanticAI is useful here:**\n", - "Dependencies are a clean way to inject runtime configuration (tenant ID, API clients, feature flags, environment context) into agents and tools without relying on global variables or string formatting prompts." + "# Run the retry-enabled agent.\n", + "agent.run_sync(\"Tell me about Albert Einstein\")\n", + "# The result is a validated `Person` run result." ] }, { "cell_type": "markdown", - "id": "6c1d10c1", + "id": "3948bb6c", "metadata": {}, "source": [ - "## Async Execution\n", + "# Tools\n", "\n", - "PydanticAI supports asynchronous execution for scalable applications." + "- Agents can call Python functions as tools\n", + "- Tools let the model interact with real functions and external systems\n", + "- Tools are useful for APIs, databases, calculations, and deterministic helpers\n", + "- Tool calls reduce the chance that the model invents facts" ] }, { "cell_type": "code", - "execution_count": 11, - "id": "b9bf9835", + "execution_count": 16, + "id": "099d9d99", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'Tokyo, the capital city of Japan, is a vibrant metropolis known for its blend of traditional culture and modern innovation. Here are some key highlights about Tokyo:\\n\\n1. **Geography**: Located on the eastern coast of Honshu, Tokyo is situated in the Kanto region. It is part of the Tokyo Metropolis, which includes 23 special wards, and is surrounded by the Saitama, Chiba, and Kanagawa prefectures.\\n\\n2. **Population**: Tokyo is one of the most populous cities in the world, with a metropolitan area that has over 37 million residents, making it a major hub for business and culture.\\n\\n3. **Culture and History**: Tokyo was originally a small fishing village named Edo. It became the political center of Japan in the early 17th century when Tokugawa Ieyasu, the founder of the Tokugawa shogunate, established his government there. The city was renamed Tokyo, meaning \"Eastern Capital,\" in 1868.\\n\\n4. **Architecture and Urban Design**: Tokyo is known for its eclectic architecture, featuring a mix of traditional structures (like temples and shrines) and modern skyscrapers. The Tokyo Tower and the Tokyo Skytree are two iconic landmarks that symbolize the city’s skyline.\\n\\n5. **Transport**: Tokyo has one of the most efficient public transportation systems in the world, including an extensive network of trains, subways, and buses. The Tokyo Metro and JR East train services are particularly notable for their punctuality and coverage.\\n\\n6. **Economy**: As a global financial center, Tokyo hosts numerous multinational corporations and is a leading city in technology, manufacturing, and commerce. The Tokyo Stock Exchange is one of the largest stock exchanges in the world.\\n\\n7. **Cuisine**: Tokyo boasts a rich culinary scene, offering everything from sushi and ramen to high-end dining experiences. It has more Michelin-starred restaurants than any other city in the world.\\n\\n8. **Tourist Attractions**: Popular destinations in Tokyo include the historic Senso-ji Temple in Asakusa, the busy shopping districts of Shibuya and Harajuku, the Imperial Palace, and the vibrant nightlife of Shinjuku.\\n\\n9. **Arts and Entertainment**: Tokyo is a cultural hub, known for its museums, art galleries, theaters, and music venues. Events like the Tokyo Anime and Comic Market celebrate Japan’s pop culture.\\n\\n10. **Parks and Nature**: Despite being a bustling urban environment, Tokyo offers several green spaces, including Ueno Park and the picturesque Shinjuku Gyoen National Garden, where residents and visitors can enjoy nature.\\n\\nTokyo\\'s unique blend of the old and new makes it a fascinating destination both for residents and tourists alike.'" + "Agent(model=OpenAIChatModel(), name=None, end_strategy='early', model_settings=None, output_type=, instrument=None)" ] }, - "execution_count": 11, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "import asyncio\n", - "\n", - "asyncio.run(utils.run_agent(agent))" + "# Create an agent with a deterministic weather tool.\n", + "agent = Agent(MODEL_ID, tools=[utils.get_weather])\n", + "agent\n", + "# The agent can call `utils.get_weather()` while answering." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "3a58783d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AgentRunResult(output='The weather in Tokyo is sunny.')" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Ask a question that should use the weather tool.\n", + "agent.run_sync(\"What is the weather in Tokyo?\")\n", + "# The run result includes the tool-backed weather answer." ] }, { "cell_type": "markdown", - "id": "c41412c5-70b4-44c1-bdb8-9c98da932144", + "id": "6bbc710d", "metadata": {}, "source": [ - "### What happened in the code\n", + "# Dependencies\n", + "\n", + "- Dependencies inject runtime context into agents and tools\n", + "- Example dependency values:\n", + " - Tenant IDs\n", + " - API clients\n", + " - Feature flags\n", + " - Environment context\n", + "- Dependencies let tools access context without global variables or prompt string formatting" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "8ffc2657", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "__main__.Config" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Define the dependency object passed into the agent at run time.\n", + "@dataclass\n", + "class Config:\n", + " company: str\n", "\n", - "- We defined an async function that calls `await agent.run(...)`.\n", - "- Async execution is helpful for applications that need concurrency (web servers, batch pipelines, background jobs).\n", - "- `asyncio.run(...)` runs the coroutine in a notebook-safe way.\n", "\n", - "**Why PydanticAI is useful here:**\n", - "Most real systems are async. PydanticAI supports async natively, so you can run many agent calls concurrently without blocking your app." + "Config\n", + "# The dependency schema describes runtime context available to tools." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "772c04ee", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Agent(model=OpenAIChatModel(), name=None, end_strategy='early', model_settings=None, output_type=, instrument=None)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create an agent that receives `Config` dependencies.\n", + "# `deps_type=Config` declares the shape of runtime context the agent can receive.\n", + "agent = Agent(MODEL_ID, deps_type=Config, tools=[utils.company_name])\n", + "agent\n", + "# Tools can access `Config` through the PydanticAI run context." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "1b9e4981", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'The configured company is OpenAI.'" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Run the dependency-aware agent with a concrete configuration.\n", + "result = agent.run_sync(\n", + " \"What company is configured?\", deps=Config(company=\"OpenAI\")\n", + ")\n", + "result.output\n", + "# The answer reflects the runtime dependency value." ] }, { @@ -583,19 +742,16 @@ "id": "9968fba5", "metadata": {}, "source": [ - "## Advanced API Features\n", - "\n", - "The following sections demonstrate more advanced capabilities of PydanticAI.\n", - "\n", - "These features are useful when building production-grade systems:\n", - "\n", - "- custom validation\n", - "- streaming outputs\n", - "- model configuration\n", - "- usage tracking\n", - "- runtime limits\n", - "\n", - "Beginners can safely skip this section on a first read." + "# Advanced Features\n", + "\n", + "- The following sections demonstrate more advanced `PydanticAI` capabilities\n", + "- These features are useful for production-grade systems:\n", + " - Custom validation\n", + " - Streaming outputs\n", + " - Model configuration\n", + " - Usage tracking\n", + " - Runtime limits\n", + "- Beginners can safely skip this section on a first read" ] }, { @@ -603,42 +759,60 @@ "id": "1ec1cef2", "metadata": {}, "source": [ - "## Result Validators\n", - "\n", - "Result validators allow you to enforce additional rules on model outputs.\n", - "\n", - "Even if the response matches the Pydantic schema, we may still want to verify\n", - "logical constraints.\n", - "\n", - "Example: if an answer claims to use documents, it must include at least one source." + "# Result Validators\n", + "\n", + "- Result validators are used to check model outputs after schema validation\n", + "- `Pydantic` validates structure automatically, but result validators enforce business rules\n", + "- A response can match the `Pydantic` schema and still fail logical constraints\n", + "- For example, this output may be valid according to the schema:\n", + " - it has an `answer`\n", + " - it has a `sources` list\n", + "- But it can still be logically wrong if:\n", + " - the source list is empty\n", + " - the `doc_id` does not exist\n", + " - the quote does not actually appear in the cited document\n", + "\n", + "- Result validators handle this second layer of validation" + ] + }, + { + "cell_type": "markdown", + "id": "6f49d16c-71a4-4d5b-9cfd-7149cdcad70f", + "metadata": {}, + "source": [ + "## Validation Flow\n", + "\n", + "- Validation happens in two stages:\n", + " - `Schema validation`: the model output must match `AnswerWithSources`\n", + " - `Business-rule validation`: the registered `output_validator` enforces citation quality rules that schema alone cannot enforce\n", + "- Execution order:\n", + " ```mermaid\n", + " flowchart LR\n", + " A[Model Output] --> B[Pydantic Schema Validation]\n", + " B --> C[output_validator]\n", + " C --> D[Final Result]\n", + " ```" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 21, "id": "c66c4d20", - "metadata": { - "lines_to_next_cell": 2 - }, + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - " Any>" + "__main__.AnswerWithSources" ] }, - "execution_count": 12, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from pydantic import BaseModel\n", - "from pydantic_ai import Agent\n", - "\n", - "MODEL_ID = \"openai:gpt-4o-mini\"\n", - "\n", - "\n", + "# Define source citation schemas with explicit references for validator examples.\n", "class SourceRef(BaseModel):\n", " doc_id: str\n", " quote: str\n", @@ -649,93 +823,237 @@ " sources: list[SourceRef]\n", "\n", "\n", + "AnswerWithSources\n", + "# The schemas describe answers that include source citations." + ] + }, + { + "cell_type": "markdown", + "id": "491a9aed-b118-45de-90c8-c37ff9256454", + "metadata": {}, + "source": [ + "## Prepare Validation Context\n", + "\n", + "- We fetch the list of valid document IDs and include it in the agent instructions\n", + "- This helps:\n", + " - reduce hallucinated references\n", + " - constrain the model to known documents" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "96f6f0ac", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'available_doc_ids': ['api',\n", + " 'billing',\n", + " 'integrations',\n", + " 'limits',\n", + " 'overview',\n", + " 'security',\n", + " 'support',\n", + " 'troubleshooting'],\n", + " 'validator_instruction_length': 260}" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Build validator instructions from local document ids.\n", + "available_doc_ids = utils.get_available_document_ids()\n", + "# Build instructions that restrict citations to the local dataset.\n", + "validator_instructions = (\n", + " \"Use the search_documents tool to retrieve evidence from local documents. \"\n", + " f\"Cite only these doc ids: {available_doc_ids}. \"\n", + " \"For each source, copy the quote text exactly from tool output.\"\n", + ")\n", + "{\n", + " \"available_doc_ids\": available_doc_ids,\n", + " \"validator_instruction_length\": len(validator_instructions),\n", + "}\n", + "# The instructions constrain citations to the local document ids." + ] + }, + { + "cell_type": "markdown", + "id": "88447504-7d02-4fbb-bfd2-1b043870b3f2", + "metadata": {}, + "source": [ + "### Create the Validator Agent\n", + "- This agent:\n", + " - generates structured output\n", + " - retrieves documents using a tool\n", + " - follows constrained citation rules\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "8a0e840b", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Agent(model=OpenAIChatModel(), name=None, end_strategy='early', model_settings=None, output_type=, instrument=None)" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create an agent that returns answers with source references.\n", + "# The agent uses structured output plus the local document-search tool.\n", "validator_agent = Agent(\n", " MODEL_ID,\n", " output_type=AnswerWithSources,\n", - " instructions=(\n", - " \"Answer with short factual statements. \"\n", - " \"If you reference documents, include sources.\"\n", - " ),\n", + " instructions=validator_instructions,\n", + " tools=[utils.search_documents],\n", ")\n", - "validator_agent.output_validator(utils.validate_sources)" + "validator_agent\n", + "# The validator agent can retrieve documents and return cited answers." ] }, { - "cell_type": "code", - "execution_count": 13, - "id": "975c50ca-65ae-4838-8d44-599fee1d461f", + "cell_type": "markdown", + "id": "67b3e81d-aeb7-4e0a-a52c-8a34077e7d09", "metadata": {}, + "source": [ + "## Add Result Validator\n", + "\n", + "- The `@output_validator` runs after schema validation and enforces business rules:\n", + " - sources must be present\n", + " - document IDs must exist\n", + " - quotes must match source documents\n", + " - duplicates are not allowed\n", + "- If validation fails, `ModelRetry` is raised, and the model is asked to generate a corrected answer." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "1a6d9743", + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Validator failure example: Answer references documents but sources are empty.\n" - ] + "data": { + "text/plain": [ + "{'validator_registered': True}" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "try:\n", - " utils.validate_sources(\n", - " AnswerWithSources(answer=\"According to the documents...\", sources=[])\n", - " )\n", - "except Exception as e:\n", - " print(\"Validator failure example:\", e)" + "# Register a result validator that checks citations against local documents.\n", + "@validator_agent.output_validator\n", + "def _validate_answer_sources(\n", + " result: AnswerWithSources,\n", + ") -> AnswerWithSources:\n", + " # Validate citations against the local document dataset.\n", + " validated_result = utils.validate_document_sources(result)\n", + " return validated_result\n", + "\n", + "\n", + "{\"validator_registered\": True}\n", + "# The validator agent now enforces schema and source-reference rules." ] }, { "cell_type": "markdown", - "id": "7899e03f-bd34-4e97-b345-2cf206a33de0", + "id": "8911b1db", "metadata": {}, "source": [ - "### What happened in the code\n", + "## Manual Failure Example\n", "\n", - "- We defined a schema `AnswerWithSources` where the model must return:\n", - " - `answer` (string)\n", - " - `sources` (list of `{doc_id, quote}`)\n", - "- We attached an `output_validator` that enforces *logical rules* beyond the schema:\n", - " - if the answer mentions docs, sources must not be empty\n", - " - max 3 sources\n", - " - no duplicate sources\n", - "- If rules fail, we raise `ModelRetry`, which tells PydanticAI to retry the model call.\n", - "\n", - "**Why PydanticAI is useful here:**\n", - "Schemas catch structural mistakes. Validators catch logical mistakes. Together, they make LLM outputs production-grade by enforcing business rules automatically." + "- We intentionally create an invalid output to demonstrate how the validator triggers a retry.\n", + "- This example bypasses the model and directly tests the validator logic." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "2d040ab3-0e5c-470b-8454-916a6ce69d65", + "metadata": {}, + "outputs": [], + "source": [ + "# Build an invalid answer object for the validator demo.\n", + "bad_answer = AnswerWithSources(\n", + " answer=\"PydanticAI supports structured outputs.\",\n", + " sources=[],\n", + ")\n", + "bad_answer\n", + "# The invalid answer is missing source citations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "validator-failure-trigger", + "metadata": {}, + "outputs": [], + "source": [ + "# Trigger the validator on the intentionally invalid answer.\n", + "_LOG.info(\"Triggering the validator with an intentionally invalid answer.\")\n", + "_validate_answer_sources(bad_answer)\n", + "# The validator raises `ModelRetry` for the missing sources." ] }, { "cell_type": "markdown", - "id": "df790772-6554-4e41-b21c-626d73c8ad79", + "id": "aa66d61a-a316-4190-bf18-6cb69a65cc9e", "metadata": {}, "source": [ - "### Validator Failure Example\n", - "\n", - "The validator can also be tested manually.\n", - "\n", - "If the validation rule fails, the validator raises `ModelRetry`, which instructs the agent to retry the LLM call with improved instructions." + "## Run the Agent\n", + "\n", + "- The agent will:\n", + " - Generate structured output\n", + " \n", + " - Validate it against the schema\n", + " \n", + " - Apply business rules\n", + " \n", + " - Retry automatically if validation fails" ] }, { "cell_type": "code", - "execution_count": 14, - "id": "1d332ae9-b4de-4501-84c8-3cea4fa772a9", + "execution_count": 27, + "id": "29534576-f630-4009-9bf2-d12d3a4cacfe", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Validated output:\n", - "\n", - "answer='The greenhouse effect is a natural process that warms the Earth’s surface. When the Sun\\'s energy reaches the Earth, some of it is reflected back to space and the rest is absorbed, warming the planet. The absorbed energy is then re-radiated as infrared energy (heat). Greenhouse gases—such as carbon dioxide, methane, and water vapor—trap some of this heat in the atmosphere, preventing it from escaping back into space, which keeps the Earth warm enough to support life.\\n\\nIncreased levels of these gases due to human activities, such as burning fossil fuels and deforestation, enhance the greenhouse effect, leading to global warming and climate change. \\n\\nSources:\\n1. Intergovernmental Panel on Climate Change (IPCC) - \"Climate Change 2021: The Physical Science Basis\".\\n2. National Aeronautics and Space Administration (NASA) - \"The Greenhouse Effect\".' sources=[SourceRef(doc_id='1', quote='Greenhouse gases trap heat in the atmosphere, preventing it from escaping back into space.'), SourceRef(doc_id='2', quote='Increased levels of these gases due to human activities lead to global warming.')]\n" - ] + "data": { + "text/plain": [ + "AnswerWithSources(answer='Atlas billing plans include Team and Enterprise plans, which offer features such as two-factor authentication (2FA). Billing details such as invoices can be managed and downloaded through the Settings > Billing section in the Atlas interface. Specific pricing or other plan tiers are not detailed in the provided documents. For exact billing options and plan details, accessing your Atlas settings or contacting support would be recommended.', sources=[SourceRef(doc_id='security', quote='Atlas supports two-factor authentication (2FA) for Team and Enterprise plans.'), SourceRef(doc_id='billing', quote='- You can download invoices from Settings > Billing.')])" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "import asyncio\n", - "\n", - "asyncio.run(utils.run_validator_example(validator_agent))" + "# Run the validator agent with the local document search tool.\n", + "validator_result = asyncio.run(utils.run_validator_example(validator_agent))\n", + "validator_result\n", + "# The validator agent returns a cited answer that passed validation." ] }, { @@ -743,75 +1061,70 @@ "id": "7828a0ab", "metadata": {}, "source": [ - "## Streaming\n", - "\n", - "Streaming allows tokens to be returned as they are generated.\n", + "# Streaming\n", "\n", - "Benefits:\n", - "\n", - "- lower perceived latency\n", - "- better user experience in chat interfaces\n", - "- progressive display of responses" + "- Streaming returns tokens as the model generates them\n", + "- Streaming benefits:\n", + " - Lower perceived latency\n", + " - Better user experience in chat interfaces\n", + " - Progressive display of responses" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 39, "id": "7fbec717", - "metadata": { - "lines_to_next_cell": 2 - }, + "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Streaming:\n", - "Unit tests are a type of software testingUnit tests are a type of software testing that focuses on verifying the correctness of individualUnit tests are a type of software testing that focuses on verifying the correctness of individual components or functions of a program in isolationUnit tests are a type of software testing that focuses on verifying the correctness of individual components or functions of a program in isolation. Typically written by developers, these testsUnit tests are a type of software testing that focuses on verifying the correctness of individual components or functions of a program in isolation. Typically written by developers, these tests evaluate the smallest parts of the software—Unit tests are a type of software testing that focuses on verifying the correctness of individual components or functions of a program in isolation. Typically written by developers, these tests evaluate the smallest parts of the software—usually functions or methods—ensuring thatUnit tests are a type of software testing that focuses on verifying the correctness of individual components or functions of a program in isolation. Typically written by developers, these tests evaluate the smallest parts of the software—usually functions or methods—ensuring that they behave as expected under various conditions.Unit tests are a type of software testing that focuses on verifying the correctness of individual components or functions of a program in isolation. Typically written by developers, these tests evaluate the smallest parts of the software—usually functions or methods—ensuring that they behave as expected under various conditions. The goal of unit testing is to identifyUnit tests are a type of software testing that focuses on verifying the correctness of individual components or functions of a program in isolation. Typically written by developers, these tests evaluate the smallest parts of the software—usually functions or methods—ensuring that they behave as expected under various conditions. The goal of unit testing is to identify bugs early in the development processUnit tests are a type of software testing that focuses on verifying the correctness of individual components or functions of a program in isolation. Typically written by developers, these tests evaluate the smallest parts of the software—usually functions or methods—ensuring that they behave as expected under various conditions. The goal of unit testing is to identify bugs early in the development process, promote code reliability, andUnit tests are a type of software testing that focuses on verifying the correctness of individual components or functions of a program in isolation. Typically written by developers, these tests evaluate the smallest parts of the software—usually functions or methods—ensuring that they behave as expected under various conditions. The goal of unit testing is to identify bugs early in the development process, promote code reliability, and facilitate easier debugging and future codeUnit tests are a type of software testing that focuses on verifying the correctness of individual components or functions of a program in isolation. Typically written by developers, these tests evaluate the smallest parts of the software—usually functions or methods—ensuring that they behave as expected under various conditions. The goal of unit testing is to identify bugs early in the development process, promote code reliability, and facilitate easier debugging and future code changes by providing a safety netUnit tests are a type of software testing that focuses on verifying the correctness of individual components or functions of a program in isolation. Typically written by developers, these tests evaluate the smallest parts of the software—usually functions or methods—ensuring that they behave as expected under various conditions. The goal of unit testing is to identify bugs early in the development process, promote code reliability, and facilitate easier debugging and future code changes by providing a safety net that confirms existing functionality remains intact. ByUnit tests are a type of software testing that focuses on verifying the correctness of individual components or functions of a program in isolation. Typically written by developers, these tests evaluate the smallest parts of the software—usually functions or methods—ensuring that they behave as expected under various conditions. The goal of unit testing is to identify bugs early in the development process, promote code reliability, and facilitate easier debugging and future code changes by providing a safety net that confirms existing functionality remains intact. By automating these tests, teams can increaseUnit tests are a type of software testing that focuses on verifying the correctness of individual components or functions of a program in isolation. Typically written by developers, these tests evaluate the smallest parts of the software—usually functions or methods—ensuring that they behave as expected under various conditions. The goal of unit testing is to identify bugs early in the development process, promote code reliability, and facilitate easier debugging and future code changes by providing a safety net that confirms existing functionality remains intact. By automating these tests, teams can increase efficiency and streamline the development workflow.---\n", - "Streaming failed; falling back to run(). 'StreamedRunResult' object has no attribute 'get_final_result'\n", - "\n", - "\n", - "Non-streamed: AgentRunResult(output='Unit tests are a type of software testing that focuses on validating individual components or functions of a program in isolation to ensure they behave as expected. By testing small, discrete sections of code—usually at the level of functions or methods—developers can identify and fix bugs early in the development process, thereby enhancing code reliability and maintainability. Unit tests are typically automated and executed frequently, allowing for rapid feedback and promoting confidence in the stability of the codebase as it evolves.')\n" - ] + "data": { + "text/plain": [ + "Agent(model=OpenAIChatModel(), name=None, end_strategy='early', model_settings=None, output_type=, instrument=None)" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ + "# Create an agent for the streaming example.\n", "stream_agent = Agent(\n", " MODEL_ID, instructions=\"Write one short paragraph about unit tests.\"\n", ")\n", - "\n", - "if not hasattr(stream_agent, \"run_stream\"):\n", - " print(\"Streaming API not available; falling back to run().\")\n", - " result = await stream_agent.run(\"What are unit tests?\")\n", - " _print_result(\"Non-streamed:\", result)\n", - "else:\n", - " try:\n", - " async with stream_agent.run_stream(\"What are unit tests?\") as stream:\n", - " print(\"Streaming:\")\n", - " async for chunk in stream.stream_text():\n", - " print(chunk, end=\"\", flush=True)\n", - " print(\"---\")\n", - " result = await stream.get_final_result()\n", - " print(\"\\n\\nFinal result:\", result)\n", - " except Exception as e:\n", - " print(\"Streaming failed; falling back to run().\", e)\n", - " result = await stream_agent.run(\"What are unit tests?\")\n", - " print(\"\\n\\nNon-streamed:\", result)" + "stream_agent\n", + "# The streaming agent is ready to produce incremental text." ] }, { - "cell_type": "markdown", - "id": "fced043d-9333-42d2-8b3d-0f80b1ed1c7b", + "cell_type": "code", + "execution_count": 40, + "id": "5a4a5245", "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Streaming output:\n", + "Unit tests are automated tests that verify the functionality of individual components or units of code, such as functions or methods, in isolation from the rest of the application. Their primary purpose is to ensure that each unit performs as expected, helping developers catch bugs early, improve code quality, and simplify maintenance. By running unit tests frequently during development, teams can identify issues quickly and confidently make changes without introducing new errors.\n" + ] + }, + { + "data": { + "text/plain": [ + "'Unit tests are automated tests that verify the functionality of individual components or units of code, such as functions or methods, in isolation from the rest of the application. Their primary purpose is to ensure that each unit performs as expected, helping developers catch bugs early, improve code quality, and simplify maintenance. By running unit tests frequently during development, teams can identify issues quickly and confidently make changes without introducing new errors.'" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "### What happened in the code\n", - "\n", - "- We created an agent and attempted to call the model using streaming mode.\n", - "- With streaming, tokens are yielded as the model generates them instead of waiting for the full response.\n", - "- This improves perceived responsiveness for chat apps and UIs.\n", - "\n", - "**Why PydanticAI is useful here:**\n", - "Streaming helps build better user experiences. You can display partial output instantly while the model continues generating, which is critical for interactive assistants." + "# Run the streaming helper and return the final result.\n", + "asyncio.run(utils.run_streaming_demo(stream_agent))\n", + "# The helper logs streamed text and returns the final result." ] }, { @@ -819,92 +1132,74 @@ "id": "52c6072a", "metadata": {}, "source": [ - "## Provider Configuration\n", - "\n", - "Model objects let you configure providers directly (e.g., base URLs).\n", + "# Provider Configuration\n", "\n", - "You can supply an explicit model object instead of a string ID. This is where you would set provider-specific options (e.g., `base_url`).\n" + "- Model objects let you configure providers directly, such as `base_url`\n", + "- Use an explicit model object when provider-specific options are needed\n" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 41, "id": "c6e3973b", - "metadata": { - "lines_to_next_cell": 2 - }, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Explicit model unavailable; using string model ID. OpenAIChatModel.__init__() got an unexpected keyword argument 'model'\n" + "Using OpenAI model with model_name='gpt-4.1-mini'.\n", + "Using explicit model object.\n" ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "/tmp/ipykernel_24/2437657520.py:4: DeprecationWarning: `OpenAIModel` was renamed to `OpenAIChatModel` to clearly distinguish it from `OpenAIResponsesModel` which uses OpenAI's newer Responses API. Use that unless you're using an OpenAI Chat Completions-compatible API, or require a feature that the Responses API doesn't support yet like audio.\n", - " explicit_model = OpenAIModel(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Explicit model (or fallback): AgentRunResult(output='Hello! How can I assist you today?')\n" - ] + "data": { + "text/plain": [ + "{'explicit_model_available': True}" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "explicit_model = None\n", - "try:\n", - " from pydantic_ai.models.openai import OpenAIModel\n", - "\n", - " explicit_model = OpenAIModel(\n", - " model=MODEL_ID.split(\":\", 1)[-1],\n", - " api_key=os.getenv(\"OPENAI_API_KEY\"),\n", - " base_url=os.getenv(\"OPENAI_BASE_URL\"),\n", - " )\n", - " print(\"Using explicit OpenAIModel.\")\n", - "except Exception:\n", - " try:\n", - " from pydantic_ai.models.openai import OpenAIChatModel\n", - "\n", - " explicit_model = OpenAIChatModel(\n", - " model=MODEL_ID.split(\":\", 1)[-1],\n", - " api_key=os.getenv(\"OPENAI_API_KEY\"),\n", - " base_url=os.getenv(\"OPENAI_BASE_URL\"),\n", - " )\n", - " print(\"Using explicit OpenAIChatModel.\")\n", - " except Exception as e2:\n", - " print(\"Explicit model unavailable; using string model ID.\", e2)\n", - "\n", - "agent = Agent(explicit_model or MODEL_ID, instructions=\"Be concise.\")\n", - "try:\n", - " result = await agent.run(\"Say hello in one sentence.\")\n", - " print(\"Explicit model (or fallback):\", result)\n", - "except Exception as e:\n", - " print(\"Error: \", e)" + "# Build an explicit provider model object when the installed API supports it.\n", + "explicit_model = utils.build_explicit_openai_model(MODEL_ID)\n", + "# Log which provider configuration path is active.\n", + "if explicit_model is None:\n", + " _LOG.info(\"Explicit model unavailable; using string model ID.\")\n", + "else:\n", + " _LOG.info(\"Using explicit model object.\")\n", + "{\"explicit_model_available\": explicit_model is not None}\n", + "# Provider configuration is either explicit or falls back to `MODEL_ID`." ] }, { - "cell_type": "markdown", - "id": "867a9074-dc6b-435d-b2ee-ff41eb7ce217", - "metadata": {}, + "cell_type": "code", + "execution_count": 42, + "id": "6b8fc187", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AgentRunResult(output='Hello! How can I assist you today?')" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "### What happened in the code\n", - "\n", - "- Instead of using a string model ID, we attempted to create an explicit provider model object.\n", - "- This allows provider-specific configuration such as:\n", - " - custom base URLs\n", - " - custom API keys\n", - " - proxy settings\n", - "- If explicit model classes aren't available in the installed version, we fall back to using the string model ID.\n", - "\n", - "**Why PydanticAI is useful here:**\n", - "Explicit provider configuration is what you use in real deployments: enterprise gateways, self-hosted endpoints, proxies, and custom routing." + "# Run an agent with the explicit provider model when available.\n", + "agent = Agent(explicit_model or MODEL_ID, instructions=\"Be concise.\")\n", + "result = asyncio.run(agent.run(\"Say hello in one sentence.\"))\n", + "result\n", + "# The result confirms that the provider configuration can execute a request." ] }, { @@ -912,61 +1207,55 @@ "id": "5c47562a", "metadata": {}, "source": [ - "## 11) AgentRun\n", - "\n", - "AgentRun objects contain metadata about an agent execution.\n", - "\n", - "This includes:\n", - "\n", - "- token usage\n", - "- message history\n", - "- tool calls\n", - "- final output" + "# AgentRun\n", + "\n", + "- `AgentRun` objects contain metadata about an agent execution\n", + "- `AgentRun` metadata includes:\n", + " - Token usage\n", + " - Message history\n", + " - Tool calls\n", + " - Final output\n", + "- Run metadata helps with:\n", + " - Observability: inspect messages and tool calls\n", + " - Cost tracking: inspect token usage\n", + " - Governance: keep execution details available for review" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 43, "id": "52652ef6", "metadata": { "lines_to_next_cell": 2 }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Output: A unit test is a type of software testing that involves testing individual components or functions of a program in isolation to ensure they work as intended.\n", - "Messages (new): 2\n", - "Usage: \n" - ] + "data": { + "text/plain": [ + "{'output': 'A unit test is a type of software test that verifies the correctness of a small, specific part of an application, typically a single function or method, to ensure it behaves as expected.',\n", + " 'messages_new': 2,\n", + " 'usage': }" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ + "# Run an agent and collect execution metadata.\n", "meta_agent = Agent(MODEL_ID, instructions=\"Answer in one sentence.\")\n", - "result = await meta_agent.run(\"What is a unit test?\")\n", + "result = asyncio.run(meta_agent.run(\"What is a unit test?\"))\n", + "# Extract execution metadata that helps inspect the run.\n", "usage = getattr(result, \"usage\", None)\n", "message_count = len(result.new_messages())\n", - "print(\"Output:\", result.output)\n", - "print(\"Messages (new):\", message_count)\n", - "print(\"Usage:\", usage)" - ] - }, - { - "cell_type": "markdown", - "id": "86d0d2b4-7b95-40e5-ba82-1a9083c41c2f", - "metadata": {}, - "source": [ - "### What happened in the code\n", - "\n", - "- We ran an agent and inspected the returned result object.\n", - "- The result object can include metadata such as:\n", - " - token usage (cost visibility)\n", - " - message history (debugging)\n", - " - tool calls (auditing agent behavior)\n", - "\n", - "**Why PydanticAI is useful here:**\n", - "When agents behave unexpectedly, metadata is how you debug and control them. This is essential for observability, cost tracking, and governance." + "run_metadata = {\n", + " \"output\": result.output,\n", + " \"messages_new\": message_count,\n", + " \"usage\": usage,\n", + "}\n", + "run_metadata\n", + "# The metadata summarizes output, message count, and usage details." ] }, { @@ -974,88 +1263,109 @@ "id": "ed489922", "metadata": {}, "source": [ - "## 12) Usage limits and model settings\n", + "# Usage Limits and Model Settings\n", "\n", - "Usage limits help control:\n", - "\n", - "- API cost\n", - "- runaway loops\n", - "- excessive token usage" + "- Usage limits help control:\n", + " - API cost\n", + " - Runaway loops\n", + " - Excessive token usage\n", + "- `PydanticAI` supports safety and cost controls for production LLM systems" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 44, "id": "76413843", - "metadata": {}, + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Model settings + usage limits:\n", - "Unit tests are automated tests that validate individual components or functions of a software application to ensure they work as intended, typically by checking their outputs against expected results.\n" + "Loaded ModelSettings and UsageLimits classes.\n" ] + }, + { + "data": { + "text/plain": [ + "{'model_settings_class': 'ModelSettings', 'usage_limits_class': 'UsageLimits'}" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "from pydantic_ai import Agent\n", - "\n", - "\n", - "# Version-tolerant imports for ModelSettings + UsageLimits\n", - "try:\n", - " # common in newer versions\n", - " from pydantic_ai import ModelSettings, UsageLimits\n", - "except Exception:\n", - " # fallback seen in some versions\n", - " from pydantic_ai.models import ModelSettings # type: ignore\n", - " from pydantic_ai.usage import UsageLimits # type: ignore\n", - "\n", - "\n", - "settings_agent = Agent(\n", - " MODEL_ID,\n", - " instructions=\"Answer in a single sentence.\",\n", - " model_settings=ModelSettings(temperature=0.2),\n", - ")\n", - "\n", - "result = await settings_agent.run(\n", - " \"Explain what unit tests are.\",\n", - " usage_limits=UsageLimits(request_limit=3),\n", - ")\n", - "\n", - "print(\"Model settings + usage limits:\")\n", - "print(result.output)" + "# Load version-tolerant classes for model settings and usage limits.\n", + "ModelSettings, UsageLimits = utils.get_settings_classes()\n", + "_LOG.info(\"Loaded ModelSettings and UsageLimits classes.\")\n", + "{\n", + " \"model_settings_class\": ModelSettings.__name__,\n", + " \"usage_limits_class\": UsageLimits.__name__,\n", + "}\n", + "# The installed PydanticAI version determines where these classes come from." ] }, { - "cell_type": "markdown", - "id": "cc8440b6-4d88-43db-9eac-6d86061d6dc4", + "cell_type": "code", + "execution_count": 45, + "id": "459e5581", "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Agent(model=OpenAIChatModel(), name=None, end_strategy='early', model_settings={'temperature': 0.2}, output_type=, instrument=None)" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "### What happened in the code\n", - "\n", - "- `ModelSettings(temperature=0.2)` controls response randomness:\n", - " - lower temperature = more deterministic outputs\n", - "- `UsageLimits(request_limit=3)` sets guardrails on usage:\n", - " - helps prevent runaway retries or excessive calls\n", - "- We ran the agent with these settings applied.\n", - "\n", - "**Why PydanticAI is useful here:**\n", - "PydanticAI makes it easy to add safety and cost controls to LLM systems. These controls matter in production where reliability and spend both need limits." + "# Create an agent with deterministic model settings.\n", + "settings_agent = Agent(\n", + " MODEL_ID,\n", + " instructions=\"Answer in a single sentence.\",\n", + " model_settings=ModelSettings(temperature=0.2),\n", + ")\n", + "settings_agent\n", + "# The agent has a low-temperature model setting." ] }, { - "cell_type": "markdown", - "id": "cddca283", + "cell_type": "code", + "execution_count": 46, + "id": "ad306084", "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Unit tests are automated tests that verify the correctness of individual components or functions of a software application in isolation.'" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "## Best Practices\n", + "# Run the settings example with a request limit.\n", + "result = asyncio.run(\n", + " settings_agent.run(\n", + " \"Explain what unit tests are.\",\n", + " usage_limits=UsageLimits(request_limit=3),\n", + " )\n", + ")\n", "\n", - "1. Always define clear schemas using Pydantic models.\n", - "2. Keep schemas simple and explicit.\n", - "3. Use retries for robustness.\n", - "4. Add tools for external integrations.\n", - "5. Use async execution for production systems." + "# Show the constrained response text.\n", + "result.output\n", + "# The response was generated with model settings and usage limits applied." ] }, { @@ -1063,10 +1373,11 @@ "id": "e1bedef2", "metadata": {}, "source": [ - "## Troubleshooting\n", - "- Missing API key: set `OPENAI_API_KEY` (or your provider-specific key).\n", - "- Event loop errors in notebooks: use `await agent.run(...)` instead of `run_sync`.\n", - "- Validation errors: revise `output_type` or the validator to match expected output.\n" + "# Troubleshooting\n", + "\n", + "- Missing API key: set `OPENAI_API_KEY` or the provider-specific key\n", + "- Event loop errors in notebooks: use `await agent.run(...)` instead of `run_sync`\n", + "- Validation errors: revise `output_type` or the validator to match expected output\n" ] } ], diff --git a/tutorials/tutorial_pydanticAI/pydanticai.API.py b/tutorials/tutorial_pydanticAI/pydanticai.API.py index d15fe8e17..eee1465c4 100644 --- a/tutorials/tutorial_pydanticAI/pydanticai.API.py +++ b/tutorials/tutorial_pydanticAI/pydanticai.API.py @@ -6,7 +6,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.19.0 +# jupytext_version: 1.16.4 # kernelspec: # display_name: Python 3 (ipykernel) # language: python @@ -17,364 +17,352 @@ # %load_ext autoreload # %autoreload 2 +# System libraries. import logging +# Third party libraries. +import numpy as np +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt -import helpers.hnotebook as ut +# %% +# System libraries. +import asyncio +import os -ut.config_notebook() +# Third party libraries. +from dataclasses import dataclass -# Initialize logger. -logging.basicConfig(level=logging.INFO) -_LOG = logging.getLogger(__name__) +import nest_asyncio +from dotenv import find_dotenv, load_dotenv +from pydantic import BaseModel +from pydantic_ai import Agent +from pydantic_ai import ModelRetry + +# Local utilities. +import pydanticai_API_utils as utils + +# Notebook-specific imports are ready for tutorial examples. # %% +# Configure notebook logging. +import logging + +# Local utility. import pydanticai_API_utils as utils +_LOG = logging.getLogger(__name__) +utils.init_logger(_LOG) +_LOG +# Notebook logging is configured for the tutorial cells. + # %% [markdown] -# ## PydanticAI API Tutorial Introduction -# -# PydanticAI is a lightweight framework for building LLM-powered applications with **structured outputs using Pydantic models**. -# -# Unlike traditional LLM APIs that return unstructured text, PydanticAI ensures responses conform to a predefined schema. -# -# This notebook covers: -# -# - Core concepts -# - Agent API -# - Structured outputs -# - Tool usage -# - Validation and retries -# - Async execution +# # Summary # -# By the end, you will understand how to build reliable LLM pipelines using structured outputs. +# - This notebook introduces `PydanticAI` APIs for building LLM workflows, including structured outputs, tools, dependencies, validators, streaming, provider configuration, run metadata, and usage limits # %% [markdown] -# # Table of Contents -# -# 1. Introduction -# 2. Why PydanticAI exists -# 3. Installation -# 4. Minimal Example -# 5. Core Concepts -# 6. Structured Outputs -# 7. Validation -# 8. Tools -# 9. Dependencies -# 10. Async Execution -# 11. Advanced Features -# 12. Best Practices -# 13. Summary +# # PydanticAI API Introduction +# +# - `PydanticAI` is a lightweight framework for building LLM-powered applications with structured outputs +# - `PydanticAI` uses `Pydantic` models to define response schemas +# - Traditional LLM APIs often return unstructured text +# - `PydanticAI` keeps responses aligned with a predefined schema # %% [markdown] -# ### Why PydanticAI Exists -# -# LLMs typically return unstructured text. -# -# Example: -# -# User prompt: -# "Extract product information from this description" -# -# LLM output: -# "The product is an iPhone 15 priced at $999." -# -# This output is difficult to use programmatically. +# ## Why PydanticAI Exists # -# What we want instead: +# - Key problem: LLMs typically return unstructured text +# - Example prompt: +# - "Extract product information from this description" +# - Example LLM output: +# - "The product is an iPhone 15 priced at $999." +# - Problem with the example LLM output: +# - The example LLM output is difficult to use programmatically +# - Desired structured output: # -# { -# "product_name": "iPhone 15", -# "price": 999 -# } +# ```json +# { +# "product_name": "iPhone 15", +# "price": 999 +# } +# ``` # -# PydanticAI solves this problem by: -# -# - Defining schemas using **Pydantic models** -# - Enforcing structured outputs -# - Automatically retrying when validation fails -# - Providing a simple agent abstraction for LLM interaction - -# %% [markdown] -# ### Mental Model -# -# ``` -# User Prompt -# v -# PydanticAI Agent -# v -# LLM -# v -# Raw Response -# v -# Pydantic Validation -# v -# Structured Output -# ``` +# - `PydanticAI` solves this problem with: +# - Schema definitions with `Pydantic` models +# - Structured output enforcement +# - Automatic retries after validation failures +# - A simple agent abstraction for LLM interaction # %% [markdown] -# ## Installation +# ## Mental Model # -# We install a minimal set of packages to keep the notebook self-contained and reproducible. -# -# This notebook uses `pydantic-ai`, `pydantic`, and `python-dotenv`. -# - -# %% -# !pip install -q pydantic-ai +# - `PydanticAI` flow: +# ```mermaid +# flowchart TD +# A[User Prompt] --> B[PydanticAI Agent] +# B --> C[LLM] +# C --> D[Raw Response] +# D --> E[Pydantic Validation] +# E --> F[Structured Output] +# ``` # %% -import os -from dotenv import load_dotenv, find_dotenv -import nest_asyncio - -nest_asyncio.apply() - - +# Load environment variables from a local dotenv file if one exists. env_path = find_dotenv(usecwd=True) load_dotenv(env_path, override=True) +_LOG.info("dotenv path: %s", env_path or "") +env_path or "" +# Environment variables are available to the model configuration cells. -MODEL_ID = os.getenv("PYDANTIC_AI_MODEL", "openai:gpt-4.1-mini") -print("dotenv path:", env_path or "") -print("PYDANTIC_AI_MODEL:", MODEL_ID) -print("OPENAI_API_KEY:", utils._mask(os.getenv("OPENAI_API_KEY"))) +# %% +# Read the model identifier from the environment. +MODEL_ID = os.getenv("PYDANTIC_AI_MODEL") +utils.log_environment(env_path, MODEL_ID) +{"model_id": MODEL_ID} +# The tutorial examples will use the configured model identifier. # %% [markdown] -# ### Running the Notebook +# # Core Concepts +# +# - `PydanticAI` revolves around a few important abstractions +# +# ## Agent +# +# - `Agent` is the main interface for interacting with the model +# - `Agent` manages: +# - LLM calls +# - Structured outputs +# - Retries +# - Tool usage +# +# ## output_type +# +# - `output_type` defines the expected structured output +# - `output_type` must be a `Pydantic` model +# +# ## Tools +# +# - Tools are functions that the agent can call during reasoning +# - Tools let agents interact with external systems such as APIs or databases # -# To run the examples you must set your API key. # -# Example: -# ``` -# export OPENAI_API_KEY="your_key_here" -# ``` # %% [markdown] -# ## Minimal Example -# -# The quickest way to understand PydanticAI is through a small example. +# # Minimal Example # -# We define a schema using Pydantic and instruct the agent to produce that structured output. +# - The quickest way to understand `PydanticAI` is a small example +# - This section defines a schema with `Pydantic` and asks the agent to produce that structured output # %% -from pydantic import BaseModel -from pydantic_ai import Agent - - +# Define the output schema for the minimal example. class City(BaseModel): name: str country: str population: int -agent = Agent("openai:gpt-4o-mini", output_type=City) +City +# The schema defines the exact output shape expected from the model. + +# %% +# Create an agent that must return `City`. +agent = Agent(MODEL_ID, output_type=City) +agent +# The agent is configured to validate model output against class `City`. + +# %% +# Run the minimal example agent. result = agent.run_sync("Tell me about Paris") result.output +# The result is a validated `City` object. # %% [markdown] -# ### What Happened? +# # Resolving the Above RuntimeError in Jupyter # -# 1. A Pydantic schema (`City`) defines the expected output structure. -# 2. The `Agent` sends the prompt to the LLM. -# 3. The LLM response is validated against the schema. -# 4. If validation succeeds, the structured result is returned. +# - Key thing to remember: Jupyter already runs an active event loop # %% [markdown] -# ## Core Concepts -# -# PydanticAI revolves around a few important abstractions. -# -# ### Agent -# -# The `Agent` is the main interface for interacting with the model. -# -# It manages: -# -# - LLM calls -# - structured outputs -# - retries -# - tool usage -# -# ### output_type -# -# Defines the expected structured output. -# -# This must be a Pydantic model. -# -# ### Tools -# -# Functions that the agent can call during reasoning. -# -# Tools allow agents to interact with external systems such as APIs or databases. -# -# - -# %% [markdown] -# ## Structured Outputs with Pydantic +# - `agent.run_sync()` can raise a `RuntimeError` in notebook environments +# - `nest_asyncio` patches the notebook event loop so nested async execution can work +# - After `nest_asyncio.apply()`, async `PydanticAI` examples can run inside notebook cells # %% -from pydantic import BaseModel +# Enable nested event loops for notebook execution. +nest_asyncio.apply() +nested_event_loop_enabled = True +_LOG.info("Nested event loop support enabled.") +nested_event_loop_enabled +# Async PydanticAI examples can now run from notebook cells. +# %% [markdown] +# - Re-run the previous cell that raised the `RuntimeError` +# %% [markdown] +# # Structured Outputs with Pydantic +# +# - `PydanticAI` turns LLM responses into structured data +# - Structured outputs help you: +# - Store validated outputs in databases +# - Feed typed objects into analytics +# - Pass structured data downstream without brittle string parsing + +# %% +# Define a product schema for structured extraction. class Product(BaseModel): name: str price: float category: str -agent = Agent("openai:gpt-4o-mini", output_type=Product) +Product +# The schema captures the product fields we want to extract. -agent.run_sync("Describe the Apple AirPods Pro").output +# %% +# Create an agent that must return `Product`. +agent = Agent(MODEL_ID, output_type=Product) +agent +# The agent is configured to return product data with typed fields. -# %% [markdown] -# ### What happened in the code -# -# - We defined a `Product` schema (name, price, category). -# - The agent is configured to produce outputs that conform to this schema. -# - When the model answers, PydanticAI validates that: -# - `price` is a number -# - fields exist with the right types -# - the structure matches exactly -# -# **Why PydanticAI is useful here:** -# This turns LLM responses into structured data you can store in databases, feed into analytics, or pass downstream in an application without brittle string parsing. +# %% +# Ask the model for structured product information. +agent.run_sync("Describe the Apple AirPods Pro").output +# The response is validated as a `Product` class object. # %% [markdown] -# ## Validation and Retries -# -# If the LLM produces an output that does not match the schema, PydanticAI automatically retries. +# # Validation and Retries # -# This greatly improves reliability. - +# - Real LLM outputs are inconsistent +# - Schema validation checks the generated structure +# - Retries let `PydanticAI` ask the model to repair invalid output +# - This notebook avoids custom parsing and retry logic in each prompt # %% +# Define a schema that requires an integer age. class Person(BaseModel): name: str age: int -agent = Agent("openai:gpt-4o-mini", output_type=Person, retries=2) +Person +# The schema enforces integer typing for age values. -agent.run_sync("Tell me about Albert Einstein") +# %% +# Configure retries so schema validation failures can be corrected. +agent = Agent(MODEL_ID, output_type=Person, retries=2) +agent +# The agent can retry when model output does not match `Person`. -# %% [markdown] -# ### What happened in the code -# -# - We defined a `Person` schema with `name` and `age`. -# - We set `retries=2` on the agent. -# - If the model output fails schema validation (missing fields, wrong types), PydanticAI automatically retries the model call to get a valid output. -# -# **Why PydanticAI is useful here:** -# Real LLM outputs are inconsistent. Automatic schema validation + retry gives you reliability without writing custom parsing and retry logic for every prompt. +# %% +# Run the retry-enabled agent. +agent.run_sync("Tell me about Albert Einstein") +# The result is a validated `Person` run result. # %% [markdown] -# ## Tools +# # Tools # -# Agents can call Python functions as tools. +# - Agents can call Python functions as tools +# - Tools let the model interact with real functions and external systems +# - Tools are useful for APIs, databases, calculations, and deterministic helpers +# - Tool calls reduce the chance that the model invents facts # %% -agent = Agent("openai:gpt-4o-mini", tools=[utils.get_weather]) +# Create an agent with a deterministic weather tool. +agent = Agent(MODEL_ID, tools=[utils.get_weather]) +agent +# The agent can call `utils.get_weather()` while answering. +# %% +# Ask a question that should use the weather tool. agent.run_sync("What is the weather in Tokyo?") +# The run result includes the tool-backed weather answer. # %% [markdown] -# ### What happened in the code -# -# - We defined a Python function `get_weather(city)` that returns a deterministic string. -# - We passed it into the agent via `tools=[get_weather]`. -# - When the user asks about weather, the agent can choose to call the tool to get the answer instead of hallucinating. -# -# **Why PydanticAI is useful here:** -# Tools let the model interact with real functions and external systems. This is how you build agents that do real work (APIs, databases, calculations) rather than confidently inventing facts. - -# %% [markdown] -# ## Dependencies +# # Dependencies # -# Dependencies allow agents to access external resources or shared state. +# - Dependencies inject runtime context into agents and tools +# - Example dependency values: +# - Tenant IDs +# - API clients +# - Feature flags +# - Environment context +# - Dependencies let tools access context without global variables or prompt string formatting # %% -from dataclasses import dataclass -from pydantic_ai import Agent - - +# Define the dependency object passed into the agent at run time. @dataclass class Config: company: str -agent = Agent("openai:gpt-4o-mini", deps_type=Config, tools=[utils.company_name]) +Config +# The dependency schema describes runtime context available to tools. + +# %% +# Create an agent that receives `Config` dependencies. +# `deps_type=Config` declares the shape of runtime context the agent can receive. +agent = Agent(MODEL_ID, deps_type=Config, tools=[utils.company_name]) +agent +# Tools can access `Config` through the PydanticAI run context. + +# %% +# Run the dependency-aware agent with a concrete configuration. result = agent.run_sync( "What company is configured?", deps=Config(company="OpenAI") ) -print(result.output) - -# %% [markdown] -# ### What happened in the code -# -# - `deps_type=Config` declares the *shape* of runtime context the agent can receive. -# - At run time, we pass an instance like `Config(company="OpenAI")`. -# - Tools (or other agent logic) can access this via `RunContext.deps`, so the agent can use configuration/state without hardcoding it into prompts. -# -# **Why PydanticAI is useful here:** -# Dependencies are a clean way to inject runtime configuration (tenant ID, API clients, feature flags, environment context) into agents and tools without relying on global variables or string formatting prompts. +result.output +# The answer reflects the runtime dependency value. # %% [markdown] -# ## Async Execution +# # Advanced Features # -# PydanticAI supports asynchronous execution for scalable applications. - -# %% -import asyncio - -asyncio.run(utils.run_agent(agent)) +# - The following sections demonstrate more advanced `PydanticAI` capabilities +# - These features are useful for production-grade systems: +# - Custom validation +# - Streaming outputs +# - Model configuration +# - Usage tracking +# - Runtime limits +# - Beginners can safely skip this section on a first read # %% [markdown] -# ### What happened in the code +# # Result Validators # -# - We defined an async function that calls `await agent.run(...)`. -# - Async execution is helpful for applications that need concurrency (web servers, batch pipelines, background jobs). -# - `asyncio.run(...)` runs the coroutine in a notebook-safe way. +# - Result validators are used to check model outputs after schema validation +# - `Pydantic` validates structure automatically, but result validators enforce business rules +# - A response can match the `Pydantic` schema and still fail logical constraints +# - For example, this output may be valid according to the schema: +# - it has an `answer` +# - it has a `sources` list +# - But it can still be logically wrong if: +# - the source list is empty +# - the `doc_id` does not exist +# - the quote does not actually appear in the cited document # -# **Why PydanticAI is useful here:** -# Most real systems are async. PydanticAI supports async natively, so you can run many agent calls concurrently without blocking your app. +# - Result validators handle this second layer of validation # %% [markdown] -# ## Advanced API Features +# ## Validation Flow # -# The following sections demonstrate more advanced capabilities of PydanticAI. -# -# These features are useful when building production-grade systems: -# -# - custom validation -# - streaming outputs -# - model configuration -# - usage tracking -# - runtime limits -# -# Beginners can safely skip this section on a first read. - -# %% [markdown] -# ## Result Validators -# -# Result validators allow you to enforce additional rules on model outputs. -# -# Even if the response matches the Pydantic schema, we may still want to verify -# logical constraints. -# -# Example: if an answer claims to use documents, it must include at least one source. +# - Validation happens in two stages: +# - `Schema validation`: the model output must match `AnswerWithSources` +# - `Business-rule validation`: the registered `output_validator` enforces citation quality rules that schema alone cannot enforce +# - Execution order: +# ```mermaid +# flowchart LR +# A[Model Output] --> B[Pydantic Schema Validation] +# B --> C[output_validator] +# C --> D[Final Result] +# ``` # %% -from pydantic import BaseModel -from pydantic_ai import Agent - -MODEL_ID = "openai:gpt-4o-mini" - - +# Define source citation schemas with explicit references for validator examples. class SourceRef(BaseModel): doc_id: str quote: str @@ -385,245 +373,243 @@ class AnswerWithSources(BaseModel): sources: list[SourceRef] -validator_agent = Agent( - MODEL_ID, - output_type=AnswerWithSources, - instructions=( - "Answer with short factual statements. " - "If you reference documents, include sources." - ), -) -validator_agent.output_validator(utils.validate_sources) +AnswerWithSources +# The schemas describe answers that include source citations. +# %% [markdown] +# ## Prepare Validation Context +# +# - We fetch the list of valid document IDs and include it in the agent instructions +# - This helps: +# - reduce hallucinated references +# - constrain the model to known documents # %% -try: - utils.validate_sources( - AnswerWithSources(answer="According to the documents...", sources=[]) - ) -except Exception as e: - print("Validator failure example:", e) +# Build validator instructions from local document ids. +available_doc_ids = utils.get_available_document_ids() +# Build instructions that restrict citations to the local dataset. +validator_instructions = ( + "Use the search_documents tool to retrieve evidence from local documents. " + f"Cite only these doc ids: {available_doc_ids}. " + "For each source, copy the quote text exactly from tool output." +) +{ + "available_doc_ids": available_doc_ids, + "validator_instruction_length": len(validator_instructions), +} +# The instructions constrain citations to the local document ids. # %% [markdown] -# ### What happened in the code +# ### Create the Validator Agent +# - This agent: +# - generates structured output +# - retrieves documents using a tool +# - follows constrained citation rules # -# - We defined a schema `AnswerWithSources` where the model must return: -# - `answer` (string) -# - `sources` (list of `{doc_id, quote}`) -# - We attached an `output_validator` that enforces *logical rules* beyond the schema: -# - if the answer mentions docs, sources must not be empty -# - max 3 sources -# - no duplicate sources -# - If rules fail, we raise `ModelRetry`, which tells PydanticAI to retry the model call. # -# **Why PydanticAI is useful here:** -# Schemas catch structural mistakes. Validators catch logical mistakes. Together, they make LLM outputs production-grade by enforcing business rules automatically. + +# %% +# Create an agent that returns answers with source references. +# The agent uses structured output plus the local document-search tool. +validator_agent = Agent( + MODEL_ID, + output_type=AnswerWithSources, + instructions=validator_instructions, + tools=[utils.search_documents], +) +validator_agent +# The validator agent can retrieve documents and return cited answers. + # %% [markdown] -# ### Validator Failure Example -# -# The validator can also be tested manually. +# ## Add Result Validator # -# If the validation rule fails, the validator raises `ModelRetry`, which instructs the agent to retry the LLM call with improved instructions. +# - The `@output_validator` runs after schema validation and enforces business rules: +# - sources must be present +# - document IDs must exist +# - quotes must match source documents +# - duplicates are not allowed +# - If validation fails, `ModelRetry` is raised, and the model is asked to generate a corrected answer. # %% -import asyncio +# Register a result validator that checks citations against local documents. +@validator_agent.output_validator +def _validate_answer_sources( + result: AnswerWithSources, +) -> AnswerWithSources: + # Validate citations against the local document dataset. + validated_result = utils.validate_document_sources(result) + return validated_result + + +{"validator_registered": True} +# The validator agent now enforces schema and source-reference rules. -asyncio.run(utils.run_validator_example(validator_agent)) # %% [markdown] -# ## Streaming -# -# Streaming allows tokens to be returned as they are generated. -# -# Benefits: +# ## Manual Failure Example # -# - lower perceived latency -# - better user experience in chat interfaces -# - progressive display of responses +# - We intentionally create an invalid output to demonstrate how the validator triggers a retry. +# - This example bypasses the model and directly tests the validator logic. # %% -stream_agent = Agent( - MODEL_ID, instructions="Write one short paragraph about unit tests." +# Build an invalid answer object for the validator demo. +bad_answer = AnswerWithSources( + answer="PydanticAI supports structured outputs.", + sources=[], ) +bad_answer +# The invalid answer is missing source citations. -if not hasattr(stream_agent, "run_stream"): - print("Streaming API not available; falling back to run().") - result = await stream_agent.run("What are unit tests?") - _print_result("Non-streamed:", result) -else: - try: - async with stream_agent.run_stream("What are unit tests?") as stream: - print("Streaming:") - async for chunk in stream.stream_text(): - print(chunk, end="", flush=True) - print("---") - result = await stream.get_final_result() - print("\n\nFinal result:", result) - except Exception as e: - print("Streaming failed; falling back to run().", e) - result = await stream_agent.run("What are unit tests?") - print("\n\nNon-streamed:", result) - +# %% +# Trigger the validator on the intentionally invalid answer. +_LOG.info("Triggering the validator with an intentionally invalid answer.") +_validate_answer_sources(bad_answer) +# The validator raises `ModelRetry` for the missing sources. # %% [markdown] -# ### What happened in the code -# -# - We created an agent and attempted to call the model using streaming mode. -# - With streaming, tokens are yielded as the model generates them instead of waiting for the full response. -# - This improves perceived responsiveness for chat apps and UIs. +# ## Run the Agent # -# **Why PydanticAI is useful here:** -# Streaming helps build better user experiences. You can display partial output instantly while the model continues generating, which is critical for interactive assistants. +# - The agent will: +# - Generate structured output +# +# - Validate it against the schema +# +# - Apply business rules +# +# - Retry automatically if validation fails + +# %% +# Run the validator agent with the local document search tool. +validator_result = asyncio.run(utils.run_validator_example(validator_agent)) +validator_result +# The validator agent returns a cited answer that passed validation. # %% [markdown] -# ## Provider Configuration -# -# Model objects let you configure providers directly (e.g., base URLs). -# -# You can supply an explicit model object instead of a string ID. This is where you would set provider-specific options (e.g., `base_url`). +# # Streaming # +# - Streaming returns tokens as the model generates them +# - Streaming benefits: +# - Lower perceived latency +# - Better user experience in chat interfaces +# - Progressive display of responses # %% -explicit_model = None -try: - from pydantic_ai.models.openai import OpenAIModel - - explicit_model = OpenAIModel( - model=MODEL_ID.split(":", 1)[-1], - api_key=os.getenv("OPENAI_API_KEY"), - base_url=os.getenv("OPENAI_BASE_URL"), - ) - print("Using explicit OpenAIModel.") -except Exception: - try: - from pydantic_ai.models.openai import OpenAIChatModel - - explicit_model = OpenAIChatModel( - model=MODEL_ID.split(":", 1)[-1], - api_key=os.getenv("OPENAI_API_KEY"), - base_url=os.getenv("OPENAI_BASE_URL"), - ) - print("Using explicit OpenAIChatModel.") - except Exception as e2: - print("Explicit model unavailable; using string model ID.", e2) - -agent = Agent(explicit_model or MODEL_ID, instructions="Be concise.") -try: - result = await agent.run("Say hello in one sentence.") - print("Explicit model (or fallback):", result) -except Exception as e: - print("Error: ", e) +# Create an agent for the streaming example. +stream_agent = Agent( + MODEL_ID, instructions="Write one short paragraph about unit tests." +) +stream_agent +# The streaming agent is ready to produce incremental text. +# %% +# Run the streaming helper and return the final result. +asyncio.run(utils.run_streaming_demo(stream_agent)) +# The helper logs streamed text and returns the final result. # %% [markdown] -# ### What happened in the code +# # Provider Configuration # -# - Instead of using a string model ID, we attempted to create an explicit provider model object. -# - This allows provider-specific configuration such as: -# - custom base URLs -# - custom API keys -# - proxy settings -# - If explicit model classes aren't available in the installed version, we fall back to using the string model ID. +# - Model objects let you configure providers directly, such as `base_url` +# - Use an explicit model object when provider-specific options are needed # -# **Why PydanticAI is useful here:** -# Explicit provider configuration is what you use in real deployments: enterprise gateways, self-hosted endpoints, proxies, and custom routing. + +# %% +# Build an explicit provider model object when the installed API supports it. +explicit_model = utils.build_explicit_openai_model(MODEL_ID) +# Log which provider configuration path is active. +if explicit_model is None: + _LOG.info("Explicit model unavailable; using string model ID.") +else: + _LOG.info("Using explicit model object.") +{"explicit_model_available": explicit_model is not None} +# Provider configuration is either explicit or falls back to `MODEL_ID`. + +# %% +# Run an agent with the explicit provider model when available. +agent = Agent(explicit_model or MODEL_ID, instructions="Be concise.") +result = asyncio.run(agent.run("Say hello in one sentence.")) +result +# The result confirms that the provider configuration can execute a request. + # %% [markdown] -# ## 11) AgentRun -# -# AgentRun objects contain metadata about an agent execution. -# -# This includes: +# # AgentRun # -# - token usage -# - message history -# - tool calls -# - final output +# - `AgentRun` objects contain metadata about an agent execution +# - `AgentRun` metadata includes: +# - Token usage +# - Message history +# - Tool calls +# - Final output +# - Run metadata helps with: +# - Observability: inspect messages and tool calls +# - Cost tracking: inspect token usage +# - Governance: keep execution details available for review # %% +# Run an agent and collect execution metadata. meta_agent = Agent(MODEL_ID, instructions="Answer in one sentence.") -result = await meta_agent.run("What is a unit test?") +result = asyncio.run(meta_agent.run("What is a unit test?")) +# Extract execution metadata that helps inspect the run. usage = getattr(result, "usage", None) message_count = len(result.new_messages()) -print("Output:", result.output) -print("Messages (new):", message_count) -print("Usage:", usage) +run_metadata = { + "output": result.output, + "messages_new": message_count, + "usage": usage, +} +run_metadata +# The metadata summarizes output, message count, and usage details. # %% [markdown] -# ### What happened in the code +# # Usage Limits and Model Settings # -# - We ran an agent and inspected the returned result object. -# - The result object can include metadata such as: -# - token usage (cost visibility) -# - message history (debugging) -# - tool calls (auditing agent behavior) -# -# **Why PydanticAI is useful here:** -# When agents behave unexpectedly, metadata is how you debug and control them. This is essential for observability, cost tracking, and governance. - -# %% [markdown] -# ## 12) Usage limits and model settings -# -# Usage limits help control: -# -# - API cost -# - runaway loops -# - excessive token usage +# - Usage limits help control: +# - API cost +# - Runaway loops +# - Excessive token usage +# - `PydanticAI` supports safety and cost controls for production LLM systems # %% -from pydantic_ai import Agent - - -# Version-tolerant imports for ModelSettings + UsageLimits -try: - # common in newer versions - from pydantic_ai import ModelSettings, UsageLimits -except Exception: - # fallback seen in some versions - from pydantic_ai.models import ModelSettings # type: ignore - from pydantic_ai.usage import UsageLimits # type: ignore +# Load version-tolerant classes for model settings and usage limits. +ModelSettings, UsageLimits = utils.get_settings_classes() +_LOG.info("Loaded ModelSettings and UsageLimits classes.") +{ + "model_settings_class": ModelSettings.__name__, + "usage_limits_class": UsageLimits.__name__, +} +# The installed PydanticAI version determines where these classes come from. +# %% +# Create an agent with deterministic model settings. settings_agent = Agent( MODEL_ID, instructions="Answer in a single sentence.", model_settings=ModelSettings(temperature=0.2), ) +settings_agent +# The agent has a low-temperature model setting. -result = await settings_agent.run( - "Explain what unit tests are.", - usage_limits=UsageLimits(request_limit=3), +# %% +# Run the settings example with a request limit. +result = asyncio.run( + settings_agent.run( + "Explain what unit tests are.", + usage_limits=UsageLimits(request_limit=3), + ) ) -print("Model settings + usage limits:") -print(result.output) - -# %% [markdown] -# ### What happened in the code -# -# - `ModelSettings(temperature=0.2)` controls response randomness: -# - lower temperature = more deterministic outputs -# - `UsageLimits(request_limit=3)` sets guardrails on usage: -# - helps prevent runaway retries or excessive calls -# - We ran the agent with these settings applied. -# -# **Why PydanticAI is useful here:** -# PydanticAI makes it easy to add safety and cost controls to LLM systems. These controls matter in production where reliability and spend both need limits. +# Show the constrained response text. +result.output +# The response was generated with model settings and usage limits applied. # %% [markdown] -# ## Best Practices +# # Troubleshooting # -# 1. Always define clear schemas using Pydantic models. -# 2. Keep schemas simple and explicit. -# 3. Use retries for robustness. -# 4. Add tools for external integrations. -# 5. Use async execution for production systems. - -# %% [markdown] -# ## Troubleshooting -# - Missing API key: set `OPENAI_API_KEY` (or your provider-specific key). -# - Event loop errors in notebooks: use `await agent.run(...)` instead of `run_sync`. -# - Validation errors: revise `output_type` or the validator to match expected output. +# - Missing API key: set `OPENAI_API_KEY` or the provider-specific key +# - Event loop errors in notebooks: use `await agent.run(...)` instead of `run_sync` +# - Validation errors: revise `output_type` or the validator to match expected output # diff --git a/tutorials/tutorial_pydanticAI/pydanticai.example.ipynb b/tutorials/tutorial_pydanticAI/pydanticai.example.ipynb index 6ae3d62ab..9cd519423 100644 --- a/tutorials/tutorial_pydanticAI/pydanticai.example.ipynb +++ b/tutorials/tutorial_pydanticAI/pydanticai.example.ipynb @@ -2,172 +2,297 @@ "cells": [ { "cell_type": "code", - "execution_count": null, - "id": "847ee414", - "metadata": {}, - "outputs": [], + "execution_count": 84, + "id": "b5b76493", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", + "# System libraries.\n", "import logging\n", "\n", + "# Third party libraries.\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "from dotenv import find_dotenv, load_dotenv" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "id": "84f2685e", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "# Import notebook-specific libraries.\n", + "import asyncio\n", + "import os\n", + "from dataclasses import dataclass\n", + "from pathlib import Path\n", + "from typing import Optional\n", "\n", - "import helpers.hnotebook as ut\n", - "\n", - "ut.config_notebook()\n", + "from IPython.display import Markdown, display\n", + "import nest_asyncio\n", + "from pydantic import BaseModel, Field\n", + "from pydantic_ai import Agent, RunContext\n", "\n", - "# Initialize logger.\n", - "logging.basicConfig(level=logging.INFO)\n", - "_LOG = logging.getLogger(__name__)" + "import helpers.hio as hio" ] }, { "cell_type": "code", - "execution_count": 3, - "id": "6a6f5a5d", - "metadata": {}, + "execution_count": 86, + "id": "bc2d1f38", + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "MODEL_ID: openai:gpt-5-2025-08-07\n", - "OPENAI_API_KEY set: True\n" + "\u001b[0m\u001b[33mWARNING\u001b[0m: Logger already initialized: skipping\n" ] + }, + { + "data": { + "text/plain": [ + "'Notebook logging initialized.'" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "import pydanticai_example_utils as utils" + "import logging\n", + "\n", + "# Local utility.\n", + "import pydanticai_example_utils as utils\n", + "\n", + "_LOG = logging.getLogger(__name__)\n", + "utils.init_logger(_LOG)\n", + "\n", + "display(\"Notebook logging initialized.\")\n", + "# Notebook and utility logging are now configured." ] }, { "cell_type": "markdown", - "id": "67a9d3fc", + "id": "b457eec4", "metadata": {}, "source": [ - "# PydanticAI Example Notebook: Atlas Support Assistant (E2E)\n", + "# Summary\n", "\n", - "This notebook builds a small \"support assistant\" for a synthetic product called **Atlas**.\n", + "- This notebook shows how to build a grounded Atlas support assistant with retrieval, structured outputs, validation, guardrails, and personalization\n", "\n", - "We will:\n", - "1. Generate a synthetic knowledge base (Markdown docs)\n", - "2. Load + chunk the docs\n", - "3. Build a simple local embedding index (no external embedding service required)\n", - "4. Add retrieval as a **PydanticAI tool**\n", - "5. Use **structured outputs** (Pydantic schema) with **citations**\n", - "6. Add **validators** to enforce rules like \"citations required\"\n", - "7. Add optional **guardrails** and **personalization**\n", + "# PydanticAI Example Notebook: Atlas Support Assistant (E2E)\n", + "\n", + "- Goal: build a small support assistant for the synthetic product **Atlas**\n", + "- Workflow:\n", + " - Generate a synthetic knowledge base\n", + " - Load and chunk the docs\n", + " - Build a local embedding index\n", + " - Add retrieval as a **PydanticAI** tool\n", + " - Use structured outputs with citations\n", + " - Add validators, guardrails, and personalization\n", "\n", - "The result is an end-to-end pattern you can reuse for real RAG assistants." + "- Outcome: an end-to-end pattern you can reuse for real retrieval-augmented assistants\n" ] }, { "cell_type": "markdown", - "id": "5158c5be", + "id": "5a42e049", "metadata": {}, "source": [ "## Setup\n", "\n", - "This cell initializes the environment and imports all required libraries.\n", - "\n", - "PydanticAI agents need:\n", - "- a model identifier (for example `openai:gpt-4o-mini`)\n", - "- a provider API key (for example `OPENAI_API_KEY`)\n", - "\n", - "Everything else in this notebook is local and self-contained.\n" + "- `PydanticAI` agents need:\n", + " - A model identifier, such as `openai:gpt-4o-mini`\n", + " - A provider API key, such as `OPENAI_API_KEY`\n", + "- Create a ```.env``` file containing these variables to be called in the notebook\n", + "- Everything else in this notebook is local and self-contained\n" ] }, { "cell_type": "code", - "execution_count": null, - "id": "37e74d34", - "metadata": {}, - "outputs": [], + "execution_count": 87, + "id": "de74ea5a", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'Nested event loop support enabled.'" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "!pip install -q pydantic-ai" + "# Enable nested event loops so async agent calls run inside the notebook.\n", + "nest_asyncio.apply()\n", + "\n", + "display(\"Nested event loop support enabled.\")\n", + "# Async notebook execution is now configured." ] }, { "cell_type": "code", - "execution_count": null, - "id": "84ba29d9", + "execution_count": 88, + "id": "ded986c6", "metadata": { "lines_to_next_cell": 2 }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "import os\n", - "import functools\n", - "from pathlib import Path\n", - "from dataclasses import dataclass\n", - "from typing import Optional\n", - "\n", - "import nest_asyncio\n", + "# Run notebook coroutines through the current event loop so the paired Python file compiles.\n", + "def _run_async(awaitable):\n", + " return asyncio.get_event_loop().run_until_complete(awaitable)\n", "\n", - "nest_asyncio.apply()\n", - "\n", - "from pydantic import BaseModel, Field\n", - "from pydantic_ai import Agent\n", "\n", - "MODEL_ID = os.getenv(\"PYDANTIC_AI_MODEL\", \"openai:gpt-4o-mini\")\n", - "print(\"MODEL_ID:\", MODEL_ID)\n", - "print(\"OPENAI_API_KEY set:\", bool(os.getenv(\"OPENAI_API_KEY\")))" + "display(_run_async)\n", + "# Notebook async calls can now run without top-level await statements." + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "id": "00e57d38", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dotenv path: /git_root/tutorials/tutorial_pydanticAI/.env\n" + ] + }, + { + "data": { + "text/plain": [ + "'/git_root/tutorials/tutorial_pydanticAI/.env'" + ] + }, + "execution_count": 89, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Load environment variables from a local dotenv file if one exists.\n", + "env_path = find_dotenv(usecwd=True)\n", + "load_dotenv(env_path, override=True)\n", + "_LOG.info(\"dotenv path: %s\", env_path or \"\")\n", + "env_path or \"\"\n", + "# Environment variables are available to the model configuration cells." ] }, { "cell_type": "markdown", - "id": "4ea9339f", + "id": "4738099c", "metadata": {}, "source": [ "## Data and Scenario\n", "\n", - "We build a tiny product docs corpus to keep the tutorial self-contained.\n", - "\n", - "We will build a tiny documentation set for an imaginary product called **Atlas**.\n" + "- This notebook uses a small product-docs corpus to stay self-contained\n", + "- The corpus describes an imaginary product called **Atlas**\n" ] }, { "cell_type": "markdown", - "id": "9f9572c7-ea17-44f3-b74f-773151df1aa5", + "id": "619d7bce", "metadata": {}, "source": [ "### What this cell does\n", "\n", - "- Creates a local folder `example_dataset/` and writes a small set of **synthetic product/support documents** as Markdown files.\n", - "- Each file represents a support knowledge-base article (billing, troubleshooting, security, limits, etc.).\n", - "- The dataset is intentionally small but diverse so retrieval can return the *right* document depending on the question.\n", - "\n", - "### Importance\n", - "\n", - "PydanticAI becomes most useful when the agent is grounded in external context (RAG-style).\n", - "These documents act as that context. In the next steps, we will:\n", - "\n", - "1. Load these Markdown files into memory\n", - "2. Retrieve relevant chunks for a user query\n", - "3. Use a PydanticAI agent + tools to answer using retrieved text\n", - "4. Return a structured output with citations" + "- Creates a local folder `example_dataset/` and writes a small set of synthetic support documents\n", + "- Uses one file per support knowledge-base article\n", + "- Keeps the dataset small so retrieval behavior stays easy to inspect\n" ] }, { "cell_type": "code", - "execution_count": 36, - "id": "7d92aad9", - "metadata": {}, + "execution_count": 90, + "id": "a215aefd", + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Docs directory: example_dataset\n", - "Files: ['limits.md', 'support.md', 'api.md', 'overview.md', 'billing.md', 'troubleshooting.md', 'security.md', 'integrations.md']\n" - ] + "data": { + "text/plain": [ + "PosixPath('example_dataset')" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ + "# Create the local directory that stores the synthetic support documents.\n", "DOCS_DIR = Path(\"example_dataset/\")\n", - "DOCS_DIR.mkdir(parents=True, exist_ok=True)\n", + "hio.create_dir(str(DOCS_DIR), incremental=True)\n", "\n", + "display(DOCS_DIR)\n", + "# The example dataset directory is now available." + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "id": "00adeefb", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['billing.md',\n", + " 'limits.md',\n", + " 'overview.md',\n", + " 'security.md',\n", + " 'support.md',\n", + " 'troubleshooting.md']" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Define the synthetic Atlas support documents used throughout the notebook.\n", "DOCS = {\n", " \"overview.md\": \"\"\"\n", "# Atlas Overview\n", @@ -237,67 +362,93 @@ "\"\"\",\n", "}\n", "\n", + "display(sorted(DOCS))\n", + "# The notebook now has a compact synthetic document corpus." + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "id": "fa2e8c1a", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['api.md',\n", + " 'billing.md',\n", + " 'integrations.md',\n", + " 'limits.md',\n", + " 'overview.md',\n", + " 'security.md',\n", + " 'support.md',\n", + " 'troubleshooting.md']" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Materialize the synthetic documents on disk if they do not already exist.\n", "for name, text in DOCS.items():\n", " path = DOCS_DIR / name\n", " if not path.exists():\n", - " path.write_text(text.strip() + \"\\n\")\n", + " path.write_text(text.strip() + \"\\n\", encoding=\"utf-8\")\n", "\n", - "print(\"Docs directory:\", DOCS_DIR)\n", - "print(\"Files:\", [p.name for p in DOCS_DIR.glob(\"*.md\")])" + "display(sorted(p.name for p in DOCS_DIR.glob(\"*.md\")))\n", + "# The synthetic knowledge-base files are now stored on disk." ] }, { "cell_type": "markdown", - "id": "455c6119-725e-4dd2-a01c-e272f52d948f", + "id": "2aa9ab53", "metadata": {}, "source": [ - "We load all Markdown files into a standard in-memory format:\n", + "- We load Markdown files into a standard in-memory format:\n", "\n", - "- `doc_id`: stable identifier for citations\n", - "- `title`: human-readable name\n", - "- `text`: document content\n", + " - `doc_id`: stable identifier for citations\n", + " - `title`: human-readable name\n", + " - `text`: document content\n", "\n", - "A consistent document schema makes it easy to:\n", - "- pass documents into dependencies (`deps`)\n", - "- build retrieval tools\n", - "- return structured citations in the agent output" + "- A consistent document schema makes it easier to build retrieval tools and return structured citations\n" ] }, { "cell_type": "markdown", - "id": "5b22ab75", + "id": "bd228906", "metadata": {}, "source": [ "## Chunking and Local Embeddings\n", "\n", - "We split each document into chunks and compute a deterministic vector for each chunk.\n", - "\n", - "### Why this approach\n", - "- It is fully local and reproducible (no external embedding API required)\n", - "- It is good enough to demonstrate retrieval and grounding\n", - "\n", - "### Importance\n", - "PydanticAI agents become far more reliable when they can retrieve relevant context via tools instead of guessing." + "- We split each document into chunks and computes a deterministic vector for each chunk\n", + "- This helps:\n", + " - Ensure it is fully local and reproducible\n", + " - Ensure it is good enough to demonstrate retrieval and grounding\n" ] }, { "cell_type": "code", - "execution_count": 37, - "id": "2f2c92cc", + "execution_count": 93, + "id": "a01b0783", "metadata": { "lines_to_next_cell": 2 }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Chunks: 8\n", - "Example: api 0\n" - ] + "data": { + "text/plain": [ + "__main__.DocChunk" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ + "# Define the chunk schema used for retrieval and citations.\n", "@dataclass\n", "class DocChunk:\n", " doc_id: str\n", @@ -306,42 +457,79 @@ " vector: list[float]\n", "\n", "\n", - "docs = utils.load_docs(DOCS_DIR)\n", - "chunks = utils.chunk_docs(docs, DocChunk, max_chars=700)\n", - "print(\"Chunks:\", len(chunks))\n", - "print(\"Example:\", chunks[0].doc_id, chunks[0].chunk_id)" + "display(DocChunk)\n", + "# The notebook now has a typed schema for retrieved chunks." ] }, { - "cell_type": "markdown", - "id": "a4937a8e", + "cell_type": "code", + "execution_count": 94, + "id": "1d2a98c7", "metadata": { "lines_to_next_cell": 2 }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'num_docs': 8, 'num_chunks': 8, 'first_chunk': ('api', 0)}" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Load the markdown documents and convert them into embedded chunks.\n", + "docs = utils.load_docs(DOCS_DIR)\n", + "chunks = utils.chunk_docs(docs, DocChunk, max_chars=700)\n", + "\n", + "display(\n", + " {\n", + " \"num_docs\": len(docs),\n", + " \"num_chunks\": len(chunks),\n", + " \"first_chunk\": (chunks[0].doc_id, chunks[0].chunk_id),\n", + " }\n", + ")\n", + "# The raw documents are now available as retrieval-ready chunks." + ] + }, + { + "cell_type": "markdown", + "id": "6d02262e", + "metadata": {}, "source": [ "## Build a lightweight search index / Retrieval\n", "\n", - "We search the chunk index for the most relevant pieces of text for a query.\n" + "- We then searche the chunk index for the most relevant pieces of text for a query\n" ] }, { "cell_type": "code", - "execution_count": 38, - "id": "124c301f", - "metadata": {}, + "execution_count": 95, + "id": "d7ef594e", + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Preview matches:\n", - "billing chunk 0 score= 0.1964\n", - "api chunk 0 score= 0.0\n", - "integrations chunk 0 score= 0.0\n" - ] + "data": { + "text/plain": [ + "{'properties': {'doc_id': {'title': 'Doc Id', 'type': 'string'},\n", + " 'chunk_id': {'title': 'Chunk Id', 'type': 'integer'},\n", + " 'score': {'title': 'Score', 'type': 'number'},\n", + " 'text': {'title': 'Text', 'type': 'string'}},\n", + " 'required': ['doc_id', 'chunk_id', 'score', 'text'],\n", + " 'title': 'DocMatch',\n", + " 'type': 'object'}" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ + "# Define the schema for previewing ranked retrieval matches.\n", "class DocMatch(BaseModel):\n", " doc_id: str\n", " chunk_id: int\n", @@ -349,58 +537,146 @@ " text: str\n", "\n", "\n", - "preview = utils.search_chunks(\n", - " chunks, \"How do I download invoices?\", DocMatch, top_k=3\n", - ")\n", - "print(\"Preview matches:\")\n", - "for m in preview:\n", - " print(m.doc_id, \"chunk\", m.chunk_id, \"score=\", round(m.score, 4))" + "display(DocMatch.model_json_schema())\n", + "# Retrieval results will now have a structured schema." ] }, { - "cell_type": "markdown", - "id": "d171abf9-2350-447d-bf24-0b42013f3bce", - "metadata": {}, + "cell_type": "code", + "execution_count": 96, + "id": "be6c8ceb", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
doc_idchunk_idscoretext
0billing00.196352# Billing and Plans\\n\\nPlans\\n- Starter: $20 p...
1api00.000000# API Access\\n\\nAPI keys\\n- Create API keys un...
2integrations00.000000# Integrations\\n\\nAtlas supports S3 and Google...
\n", + "
" + ], + "text/plain": [ + " doc_id chunk_id score text\n", + "0 billing 0 0.196352 # Billing and Plans\\n\\nPlans\\n- Starter: $20 p...\n", + "1 api 0 0.000000 # API Access\\n\\nAPI keys\\n- Create API keys un...\n", + "2 integrations 0 0.000000 # Integrations\\n\\nAtlas supports S3 and Google..." + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "### Importance\n", + "# Search the chunk index with a realistic support question.\n", + "preview = utils.search_chunks(\n", + " chunks,\n", + " \"How do I download invoices?\",\n", + " DocMatch,\n", + " top_k=3,\n", + ")\n", "\n", - "- We represent each document chunk as a vector and compute similarity with a query vector using dot product.\n", - "- `search_chunks(...)` ranks chunks by similarity and returns the top matches.\n" + "display(pd.DataFrame([match.model_dump() for match in preview]))\n", + "# The preview shows which document chunks rank highest for the query." ] }, { "cell_type": "markdown", - "id": "3cb9d1c8-5510-48af-9101-f18fe9b877b6", + "id": "1876052a", "metadata": {}, "source": [ "## Dependencies and Output Schema\n", "\n", - "### Dependencies (`DocDeps`)\n", - "Dependencies are runtime context passed into the agent at execution time. Here we store:\n", - "- the chunk index\n", - "- an optional user profile (for personalization)\n", - "\n", - "### Output schema (`AnswerWithSources`)\n", - "The agent output is forced into a structured format:\n", - "- `answer`: the response text\n", - "- `sources`: citations with `doc_id`, `chunk_id`, and a short quote\n", - "- `follow_up_questions`: optional list to support guardrails\n", - "\n", - "\n", - "Structured outputs eliminate brittle parsing and make results usable in real applications." + "- Dependencies are runtime context passed into the agent at execution time\n", + "- The output schema keeps answers and citations in a predictable format\n" ] }, { "cell_type": "code", - "execution_count": 41, - "id": "829b91aa-c7ba-4e76-bdb0-25bfa55fd944", - "metadata": {}, - "outputs": [], + "execution_count": 97, + "id": "a37fd3c8", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'$defs': {'SourceRef': {'properties': {'doc_id': {'title': 'Doc Id',\n", + " 'type': 'string'},\n", + " 'chunk_id': {'title': 'Chunk Id', 'type': 'integer'},\n", + " 'quote': {'title': 'Quote', 'type': 'string'}},\n", + " 'required': ['doc_id', 'chunk_id', 'quote'],\n", + " 'title': 'SourceRef',\n", + " 'type': 'object'}},\n", + " 'properties': {'answer': {'title': 'Answer', 'type': 'string'},\n", + " 'sources': {'items': {'$ref': '#/$defs/SourceRef'},\n", + " 'title': 'Sources',\n", + " 'type': 'array'},\n", + " 'follow_up_questions': {'items': {'type': 'string'},\n", + " 'title': 'Follow Up Questions',\n", + " 'type': 'array'}},\n", + " 'required': ['answer'],\n", + " 'title': 'AnswerWithSources',\n", + " 'type': 'object'}" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ + "# Define the dependency and output schemas used by the agent.\n", "@dataclass\n", "class DocDeps:\n", " chunks: list[DocChunk]\n", - " user: Optional[\"UserProfile\"] = None # optional personalization\n", + " user: Optional[\"UserProfile\"] = None # Optional personalization.\n", "\n", "\n", "class SourceRef(BaseModel):\n", @@ -414,66 +690,94 @@ " sources: list[SourceRef] = Field(default_factory=list)\n", " follow_up_questions: list[str] = Field(\n", " default_factory=list\n", - " ) # enables guardrails section later\n", + " ) # Optional prompts for follow-up guidance.\n", "\n", "\n", "@dataclass\n", "class UserProfile:\n", " plan: str\n", - " region: str" + " region: str\n", + "\n", + "\n", + "display(AnswerWithSources.model_json_schema())\n", + "# The agent interface is now defined with structured dependencies and output." ] }, { "cell_type": "markdown", - "id": "307ca585-5279-4bee-96cd-571318ec350c", + "id": "783b9ddc", "metadata": {}, "source": [ "## Retrieval Tool\n", "\n", - "We wrap retrieval into a tool so the agent can call it during reasoning.\n", - "Tools are the bridge between an LLM and real functionality. Here the tool provides grounded context for RAG-style answers." + "- We then wrap the retrieval into a tool so the agent can call it during reasoning\n", + "- Tools connect the model to real functionality\n" ] }, { "cell_type": "code", - "execution_count": 42, - "id": "636e093d-7834-4cc1-a09c-9bf52d109d96", + "execution_count": 98, + "id": "fce982f5", "metadata": { "lines_to_next_cell": 2 }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "search_docs_tool = functools.partial(\n", - " utils.search_docs,\n", - " doc_match_cls=DocMatch,\n", - ")" + "# Bind the retrieval helper into a tool the agent can invoke.\n", + "def search_docs_tool(\n", + " ctx: RunContext[DocDeps], query: str, top_k: int = 3\n", + "):\n", + " return utils.search_docs(\n", + " ctx,\n", + " query,\n", + " top_k=top_k,\n", + " doc_match_cls=DocMatch,\n", + " )\n", + "\n", + "display(search_docs_tool)\n", + "# The retrieval function is now packaged as a callable tool." ] }, { "cell_type": "markdown", - "id": "00d20271", + "id": "1b21a0da", "metadata": {}, "source": [ "## Agent Configuration and Validation\n", "\n", - "This agent has:\n", - "- tools: retrieval\n", - "- deps: chunk store and optional user profile\n", - "- structured output: answer plus citations\n", - "- validator: enforces citation rules and triggers retry\n", - "\n", - "The schema ensures output structure, and the validator ensures output quality. Together they turn a chatty model into a reliable system component." + "- This agent combines retrieval tools, structured outputs, and a validator that enforces citation rules\n" ] }, { "cell_type": "code", - "execution_count": 43, - "id": "c3bcc7a1-3bbc-4b0f-8537-82736d46c256", + "execution_count": 99, + "id": "c2a0b2e6", "metadata": { "lines_to_next_cell": 2 }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Agent(model=OpenAIChatModel(), name=None, end_strategy='early', model_settings=None, output_type=, instrument=None)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ + "# Configure the Atlas support agent with retrieval and structured output.\n", "agent = Agent(\n", " MODEL_ID,\n", " deps_type=DocDeps,\n", @@ -482,287 +786,400 @@ " instructions=(\n", " \"You are Atlas Support. \"\n", " \"Use the `search_docs` tool to find relevant text. \"\n", - " \"Answer briefly. If you use document info, include 1-3 sources with doc_id, chunk_id, and short quotes.\"\n", + " \"Answer briefly. If you use document info, include 1-3 sources with \"\n", + " \"doc_id, chunk_id, and short quotes.\"\n", " ),\n", ")\n", - "agent.output_validator(utils.enforce_sources)" + "agent.output_validator(utils.enforce_sources)\n", + "\n", + "display(agent)\n", + "# The support agent is now ready to answer grounded questions." ] }, { "cell_type": "markdown", - "id": "0d074084", + "id": "4902e55c", "metadata": {}, "source": [ - "## End-to-End Query\n", + "- The validator runs after the model produces a schema-valid `AnswerWithSources` object\n", "\n", - "We run the agent asynchronously using `await` (notebook-safe).\n", - "\n", - "### What happened\n", - "- The agent can call `search_docs` to retrieve relevant text\n", - "- The model generates a structured response\n", - "- The validator ensures citations exist if docs were referenced\n", + "- The schema checks structure\n", + "- The validator checks reliability rules such as source coverage\n" + ] + }, + { + "cell_type": "markdown", + "id": "b101bc2c", + "metadata": {}, + "source": [ + "## End-to-End Query\n", "\n", - "This is the full pattern: RAG grounding plus structured outputs plus reliability checks." + "- Here we run the agent asynchronously\n", + "- Key pattern:\n", + " - Retrieval grounding\n", + " - Structured outputs\n", + " - Reliability checks\n" ] }, { "cell_type": "code", - "execution_count": 45, - "id": "418afe3d", - "metadata": {}, + "execution_count": 100, + "id": "dfdebc6c", + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [ { "data": { "text/plain": [ - "AnswerWithSources(answer='Go to Settings > Billing in Atlas, find the invoice you need, and click Download. Invoices are issued on the first of each month, so you’ll find monthly invoices there.', sources=[SourceRef(doc_id='billing', chunk_id=0, quote='Invoices are issued on the first of each month.\\nYou can download invoices from Settings > Billing.')], follow_up_questions=['Do you need help finding a specific month’s invoice?'])" + "AnswerWithSources(answer='To download invoices, go to **Settings > Billing**. There, you can find and download your invoices.', sources=[SourceRef(doc_id='billing', chunk_id=0, quote='You can download invoices from Settings > Billing.')], follow_up_questions=[\"What if I can't find my invoice?\", 'Can I receive invoices via email?'])" ] }, - "execution_count": 45, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ + "# Ask an end-to-end support question using the retrieval-augmented agent.\n", "deps = DocDeps(chunks=chunks)\n", - "out = await utils.ask(\"How do I download invoices?\", deps, agent)\n", - "out" + "out = _run_async(utils.ask(\"How do I download invoices?\", deps, agent))\n", + "\n", + "display(out)\n", + "# The agent returned a structured answer object." ] }, { "cell_type": "code", - "execution_count": 46, - "id": "fc8fdf64-e860-4e51-a1ae-7dc054cc1a46", - "metadata": {}, + "execution_count": 101, + "id": "f363a29e", + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Answer:\n", - " Go to Settings > Billing in Atlas, find the invoice you need, and click Download. Invoices are issued on the first of each month, so you’ll find monthly invoices there.\n", - "\n", - "Sources:\n", - "- billing (chunk 0): Invoices are issued on the first of each month.\n", - "You can download invoices from Settings > Billing.\n", - "\n", - "Follow-ups:\n", - "- Do you need help finding a specific month’s invoice?\n" - ] + "data": { + "text/markdown": [ + "### Answer\n", + "To download invoices, go to **Settings > Billing**. There, you can find and download your invoices.\n", + "\n", + "### Sources\n", + "- `billing` chunk 0: You can download invoices from Settings > Billing.\n", + "\n", + "### Follow-up questions\n", + "- What if I can't find my invoice?\n", + "- Can I receive invoices via email?" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "print(\"Answer:\\n\", out.answer)\n", - "print(\"\\nSources:\")\n", - "for s in out.sources:\n", - " print(\n", - " f\"- {s.doc_id} (chunk {s.chunk_id}): {s.quote[:120].replace('\\\\n', ' ')}\"\n", - " )\n", - "if out.follow_up_questions:\n", - " print(\"\\nFollow-ups:\")\n", - " for q in out.follow_up_questions:\n", - " print(\"-\", q)" + "# Render the answer and citations in a notebook-friendly format.\n", + "source_lines = [\n", + " f\"- `{source.doc_id}` chunk {source.chunk_id}: \"\n", + " f\"{source.quote[:120].replace(chr(10), ' ')}\"\n", + " for source in out.sources\n", + "]\n", + "follow_up_lines = [f\"- {question}\" for question in out.follow_up_questions]\n", + "answer_sections = [\n", + " \"### Answer\",\n", + " out.answer,\n", + " \"\",\n", + " \"### Sources\",\n", + " *source_lines,\n", + "]\n", + "if follow_up_lines:\n", + " answer_sections.extend([\"\", \"### Follow-up questions\", *follow_up_lines])\n", + "\n", + "display(Markdown(\"\\n\".join(answer_sections)))\n", + "# The notebook now displays the answer alongside its citations." ] }, { "cell_type": "markdown", - "id": "dd86e62b-fe33-4e30-b65e-b76054c7b256", + "id": "56f4b350", "metadata": {}, "source": [ "## Consuming Structured Output\n", "\n", - "We print the answer and citations from the structured result object. Downstream systems can store citations, audit answers, and render sources cleanly without parsing raw text." + "- Structured results help downstream systems store citations and audit answers without parsing raw text\n" ] }, { "cell_type": "code", - "execution_count": 47, - "id": "53d0b657-ba3f-4aae-a073-5c52ae052506", - "metadata": {}, + "execution_count": 102, + "id": "0c1b0b29", + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Validator failure example: You referenced docs/policies but did not include sources.\n" + "data": { + "text/plain": [ + "AnswerWithSources(answer='According to the policy...', sources=[], follow_up_questions=[])" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Build an intentionally invalid answer object for validator inspection.\n", + "invalid_answer = AnswerWithSources(\n", + " answer=\"According to the policy...\",\n", + " sources=[],\n", + ")\n", + "\n", + "display(invalid_answer)\n", + "# This object is missing sources even though the answer claims to reference policy text." + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "id": "b41ea3d6", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "ename": "ModelRetry", + "evalue": "You referenced docs/policies but did not include sources.", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mModelRetry\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[103]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m# Run the validator to show how it rejects unsupported document-backed claims.\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m utils.enforce_sources(invalid_answer)\n", + "\u001b[36mFile \u001b[39m\u001b[32m/git_root/tutorials/tutorial_pydanticAI/pydanticai_example_utils.py:115\u001b[39m, in \u001b[36menforce_sources\u001b[39m\u001b[34m(result)\u001b[39m\n\u001b[32m 101\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34menforce_sources\u001b[39m(result: Any) -> Any:\n\u001b[32m 102\u001b[39m answer_l = result.answer.lower()\n\u001b[32m 103\u001b[39m mentions_docs = \u001b[38;5;28many\u001b[39m(\n\u001b[32m 104\u001b[39m tok \u001b[38;5;129;01min\u001b[39;00m answer_l\n\u001b[32m 105\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m tok \u001b[38;5;129;01min\u001b[39;00m [\n\u001b[32m 106\u001b[39m \u001b[33m\"\u001b[39m\u001b[33maccording\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 107\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mdocs\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 108\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mdocument\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 109\u001b[39m \u001b[33m\"\u001b[39m\u001b[33msettings\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 110\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mbilling\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 111\u001b[39m \u001b[33m\"\u001b[39m\u001b[33minvoice\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 112\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mplan\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 113\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mlimit\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 114\u001b[39m ]\n\u001b[32m--> \u001b[39m\u001b[32m115\u001b[39m )\n\u001b[32m 116\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m mentions_docs \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m result.sources:\n\u001b[32m 117\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m ModelRetry(\u001b[33m\"\u001b[39m\u001b[33mYou referenced docs/policies but did not include sources.\u001b[39m\u001b[33m\"\u001b[39m)\n", + "\u001b[31mModelRetry\u001b[39m: You referenced docs/policies but did not include sources." ] } ], "source": [ - "try:\n", - " utils.enforce_sources(\n", - " AnswerWithSources(answer=\"According to the policy...\", sources=[])\n", - " )\n", - "except Exception as e:\n", - " print(\"Validator failure example:\", e)" + "# Run the validator to show how it rejects unsupported document-backed claims.\n", + "utils.enforce_sources(invalid_answer)" ] }, { "cell_type": "markdown", - "id": "a1100208-fa7a-4d38-9a01-74070d293143", + "id": "0ee21a65", "metadata": {}, "source": [ - "### What happened (and why PydanticAI helps)\n", + "### What happened\n", "\n", - "This shows the validator catching an invalid output.\n", - "In a real run, `ModelRetry` tells PydanticAI to retry until the output meets the citation rules." + "- The validator raises `ModelRetry` when an answer cites documentation without including sources\n" ] }, { "cell_type": "markdown", - "id": "b9b94593", + "id": "222a785b", "metadata": {}, "source": [ "## Streaming Output\n", "\n", - "Streaming returns tokens progressively, which improves perceived latency in chat interfaces.\n", + "- Streaming returns tokens progressively\n", + "- Progressive output improves perceived latency in chat interfaces\n" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "id": "98b021cf", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Agent(model=OpenAIChatModel(), name=None, end_strategy='early', model_settings=None, output_type=, instrument=None)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Create a small streaming agent for a short demonstration.\n", + "stream_agent = Agent(\n", + " MODEL_ID,\n", + " instructions=\"Write one short paragraph about unit tests.\",\n", + ")\n", "\n", - "Streaming is useful for UI experiences and interactive assistants, especially when responses are longer." + "display(stream_agent)\n", + "# The streaming demonstration agent is now configured." ] }, { "cell_type": "code", - "execution_count": 50, - "id": "9ae13d6b", - "metadata": {}, + "execution_count": 105, + "id": "48134259", + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Unit tests are small, automated checks that verify a single, isolated piece of code—usually a functionUnit tests are small, automated checks that verify a single, isolated piece of code—usually a function or class—behaves as intended.Unit tests are small, automated checks that verify a single, isolated piece of code—usually a function or class—behaves as intended. Written and run by developers, they execute quickly andUnit tests are small, automated checks that verify a single, isolated piece of code—usually a function or class—behaves as intended. Written and run by developers, they execute quickly and often, catching regressions earlyUnit tests are small, automated checks that verify a single, isolated piece of code—usually a function or class—behaves as intended. Written and run by developers, they execute quickly and often, catching regressions early and enabling safe refactoring. Effective unit tests areUnit tests are small, automated checks that verify a single, isolated piece of code—usually a function or class—behaves as intended. Written and run by developers, they execute quickly and often, catching regressions early and enabling safe refactoring. Effective unit tests are deterministic, focus on one behaviorUnit tests are small, automated checks that verify a single, isolated piece of code—usually a function or class—behaves as intended. Written and run by developers, they execute quickly and often, catching regressions early and enabling safe refactoring. Effective unit tests are deterministic, focus on one behavior, and isolate external dependencies with mocks or stubsUnit tests are small, automated checks that verify a single, isolated piece of code—usually a function or class—behaves as intended. Written and run by developers, they execute quickly and often, catching regressions early and enabling safe refactoring. Effective unit tests are deterministic, focus on one behavior, and isolate external dependencies with mocks or stubs. A solid unit test suite documents intended behavior, improves codeUnit tests are small, automated checks that verify a single, isolated piece of code—usually a function or class—behaves as intended. Written and run by developers, they execute quickly and often, catching regressions early and enabling safe refactoring. Effective unit tests are deterministic, focus on one behavior, and isolate external dependencies with mocks or stubs. A solid unit test suite documents intended behavior, improves code quality, and speeds up developmentUnit tests are small, automated checks that verify a single, isolated piece of code—usually a function or class—behaves as intended. Written and run by developers, they execute quickly and often, catching regressions early and enabling safe refactoring. Effective unit tests are deterministic, focus on one behavior, and isolate external dependencies with mocks or stubs. A solid unit test suite documents intended behavior, improves code quality, and speeds up development.\n", + "Unit tests are aUnit tests are a type of automated software testing designed to validateUnit tests are a type of automated software testing designed to validate individual components or functions of a program in isolation. By testing these small units of code independentlyUnit tests are a type of automated software testing designed to validate individual components or functions of a program in isolation. By testing these small units of code independently, developers can ensure that each part behaves as expectedUnit tests are a type of automated software testing designed to validate individual components or functions of a program in isolation. By testing these small units of code independently, developers can ensure that each part behaves as expected, which helps identify bugs earlyUnit tests are a type of automated software testing designed to validate individual components or functions of a program in isolation. By testing these small units of code independently, developers can ensure that each part behaves as expected, which helps identify bugs early in the development process. Unit tests areUnit tests are a type of automated software testing designed to validate individual components or functions of a program in isolation. By testing these small units of code independently, developers can ensure that each part behaves as expected, which helps identify bugs early in the development process. Unit tests are typically written alongside the code they test and provide a safety net forUnit tests are a type of automated software testing designed to validate individual components or functions of a program in isolation. By testing these small units of code independently, developers can ensure that each part behaves as expected, which helps identify bugs early in the development process. Unit tests are typically written alongside the code they test and provide a safety net for refactoring, allowing developers toUnit tests are a type of automated software testing designed to validate individual components or functions of a program in isolation. By testing these small units of code independently, developers can ensure that each part behaves as expected, which helps identify bugs early in the development process. Unit tests are typically written alongside the code they test and provide a safety net for refactoring, allowing developers to make changes with confidence that existing functionality will remain intact. They are an essential practice in agile development and continuous integration, promoting better code quality and maintainability.\n", "\n" ] } ], "source": [ - "stream_agent = Agent(\n", - " MODEL_ID, instructions=\"Write one short paragraph about unit tests.\"\n", - ")\n", - "await utils.stream_demo(stream_agent)" + "# Stream a short response into the notebook output area.\n", + "_run_async(utils.stream_demo(stream_agent))" ] }, { "cell_type": "markdown", - "id": "2ae2be45", + "id": "b829ea0e", "metadata": {}, "source": [ "## Conversation memory (multi-turn)\n", "\n", - "Reuse message history to keep context across turns.\n" + "- Reuse message history to keep context across turns\n" ] }, { "cell_type": "code", - "execution_count": 52, - "id": "9e55f9e2", + "execution_count": 106, + "id": "5d8619a6", "metadata": { "lines_to_next_cell": 2 }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "answer='No—2FA is only available on Team and Enterprise plans.' sources=[SourceRef(doc_id='security', chunk_id=0, quote='Atlas supports two-factor authentication (2FA) for Team and Enterprise plans.'), SourceRef(doc_id='billing', chunk_id=0, quote='Plans - Starter: $20 per month... - Team: $80 per month... - Enterprise: custom pricing, SSO, dedicated success manager.')] follow_up_questions=['Would you like help upgrading to the Team plan to enable 2FA?']\n" - ] + "data": { + "text/plain": [ + "AnswerWithSources(answer='You can enable two-factor authentication (2FA) under **Settings > Security**. This feature is available for Team and Enterprise plans.', sources=[SourceRef(doc_id='security', chunk_id=0, quote='Enable it under Settings > Security.')], follow_up_questions=[])" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ + "# Ask an initial question and validate the grounded response.\n", "deps = DocDeps(chunks=chunks)\n", - "first = await agent.run(\"Where do I enable 2FA?\", deps=deps)\n", + "first = _run_async(agent.run(\"Where do I enable 2FA?\", deps=deps))\n", "utils.enforce_sources(first.output)\n", - "follow_up = await agent.run(\n", - " \"Does that work on the Starter plan?\",\n", - " deps=deps,\n", - " message_history=first.new_messages(),\n", + "\n", + "display(first.output)\n", + "# The first turn establishes grounded context for the next question." + ] + }, + { + "cell_type": "code", + "execution_count": 107, + "id": "370b1ad4", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AnswerWithSources(answer='No, two-factor authentication (2FA) is not available on the Starter plan. It is only supported for Team and Enterprise plans.', sources=[SourceRef(doc_id='security', chunk_id=0, quote='Atlas supports two-factor authentication (2FA) for Team and Enterprise plans.')], follow_up_questions=[])" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Reuse the first turn's message history in a follow-up question.\n", + "follow_up = _run_async(\n", + " agent.run(\n", + " \"Does that work on the Starter plan?\",\n", + " deps=deps,\n", + " message_history=first.new_messages(),\n", + " )\n", ")\n", "utils.enforce_sources(follow_up.output)\n", - "print(follow_up.output)" + "\n", + "display(follow_up.output)\n", + "# The follow-up answer reuses prior context through message history." ] }, { "cell_type": "markdown", - "id": "d683f0b1", + "id": "dd4eed90", "metadata": {}, "source": [ "## Guardrails (lightweight)\n", "\n", - "Reject out-of-scope questions without calling the model.\n" + "- Reject out-of-scope questions without calling the model\n" ] }, { "cell_type": "code", - "execution_count": 48, - "id": "97db2a0b", + "execution_count": 108, + "id": "e90d77c8", "metadata": { "lines_to_next_cell": 2 }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "answer='I can only help with Atlas product documentation and support questions.' sources=[] follow_up_questions=['Do you have a question about Atlas setup, billing, or support?']\n" - ] + "data": { + "text/plain": [ + "AnswerWithSources(answer='I can only help with Atlas product documentation and support questions.', sources=[], follow_up_questions=['Do you have a question about Atlas setup, billing, or support?'])" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "guarded = await utils.run_guarded(\n", - " \"Write me a poem about the ocean.\",\n", - " DocDeps(chunks=chunks),\n", - " agent,\n", - " AnswerWithSources,\n", + "# Run a guardrail check against an out-of-scope prompt.\n", + "guarded = _run_async(\n", + " utils.run_guarded(\n", + " \"Write me a poem about the ocean.\",\n", + " DocDeps(chunks=chunks),\n", + " agent,\n", + " AnswerWithSources,\n", + " )\n", ")\n", - "print(guarded)" + "\n", + "display(guarded)\n", + "# The guardrail returns a bounded response without invoking the main workflow." ] }, { "cell_type": "markdown", - "id": "72992f9d", + "id": "0594b1b6", "metadata": {}, "source": [ "## Dynamic updates\n", "\n", - "Add new docs, rebuild the index, and query again.\n" + "- Add new docs, rebuild the index, and query again\n" ] }, { "cell_type": "code", - "execution_count": 57, - "id": "66f650bf-4600-4f7a-89d7-775f55aafac4", + "execution_count": 109, + "id": "8a562607", "metadata": { "lines_to_next_cell": 2 }, - "outputs": [], - "source": [ - "from pathlib import Path" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "49e73d96", - "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - " 8\n", - "First item type: \n", - "First item preview: {'doc_id': 'api', 'title': 'Api', 'text': '# API Access\\n\\nAPI keys\\n- Create API keys under Settings > Developer.\\n- Ke\n", - "Answer:\n", - " Yes—Atlas supports Amazon S3 as a data source.\n", - "\n", - "Sources:\n", - "- integrations (chunk 0): Atlas supports S3 and Google Cloud Storage as data sources.\n" - ] + "data": { + "text/plain": [ + "PosixPath('example_dataset/integrations.md')" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "from pathlib import Path\n", - "\n", - "# 1) Add the new doc\n", + "# Add a new support document to the local knowledge base.\n", "new_doc = DOCS_DIR / \"integrations.md\"\n", "new_doc.write_text(\n", " \"\"\"\n", @@ -775,82 +1192,146 @@ " encoding=\"utf-8\",\n", ")\n", "\n", - "# 2) Reload docs in the expected dict format\n", - "docs = utils.load_docs(DOCS_DIR) # must return list[dict] with doc_id/title/text\n", + "display(new_doc)\n", + "# The knowledge base now includes an integrations document." + ] + }, + { + "cell_type": "code", + "execution_count": 110, + "id": "d95839a2", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'num_docs': 8, 'num_chunks': 8}" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Reload the documents and rebuild the retrieval chunks.\n", + "docs = utils.load_docs(DOCS_DIR)\n", "chunks = utils.chunk_docs(docs, DocChunk, max_chars=700)\n", "\n", - "# 3) Run the agent (notebook-safe)\n", + "display({\"num_docs\": len(docs), \"num_chunks\": len(chunks)})\n", + "# The retrieval index now includes the newly added document." + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "id": "4f7581d3", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AnswerWithSources(answer='Yes, we support S3 as a data source.', sources=[SourceRef(doc_id='integrations', chunk_id=0, quote='Atlas supports S3 and Google Cloud Storage as data sources.')], follow_up_questions=['What is S3 integration used for?', 'Are there any limitations with S3 support?'])" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Query the updated knowledge base about integrations support.\n", "deps = DocDeps(chunks=chunks)\n", - "\n", - "res = await agent.run(\"Do you support S3?\", deps=deps)\n", + "res = _run_async(agent.run(\"Do you support S3?\", deps=deps))\n", "out = res.output\n", "\n", - "print(\"Answer:\\n\", out.answer)\n", - "print(\"\\nSources:\")\n", - "for s in out.sources:\n", - " print(\n", - " f\"- {s.doc_id} (chunk {s.chunk_id}): {s.quote[:120].replace('\\\\n', ' ')}\"\n", - " )" + "display(out)\n", + "# The updated index returns a grounded answer about S3 support." ] }, { "cell_type": "markdown", - "id": "70b34275", + "id": "483569de", "metadata": {}, "source": [ "## Personalization via Dependencies\n", "\n", - "We pass a `UserProfile` through dependencies so the agent can tailor answers. Dependencies are the clean way to inject user context, tenant context, and configuration into tools and agent behavior without global state or prompt hacks." + "- Here we pass a `UserProfile` through dependencies so the agent can tailor answers\n", + "- Dependencies are a clean way to inject user context, tenant context, and configuration into tools\n" ] }, { "cell_type": "code", - "execution_count": 49, - "id": "e882fddd", - "metadata": {}, + "execution_count": 112, + "id": "21c448c3", + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [ { "data": { "text/plain": [ - "AnswerWithSources(answer='Here are the current Atlas limits by plan:\\n- Starter: 30 API requests per minute; 10 GB total storage.\\n- Team: 120 API requests per minute; 200 GB total storage.\\n\\nWhich plan are you on? I can confirm the exact limits for your workspace.', sources=[SourceRef(doc_id='limits', chunk_id=0, quote='API requests are limited to 120 per minute on Team. Starter is limited to 30 per minute. Starter: 10 GB total storage. Team: 200 GB total storage.')], follow_up_questions=['Which plan is your workspace on (Starter or Team)?'])" + "{'user': UserProfile(plan='Starter', region='US'),\n", + " 'num_chunks': 8,\n", + " 'sample_chunk': ('api', 0)}" ] }, - "execution_count": 49, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ + "# Create personalized dependencies for a Starter-plan user.\n", "personalized_deps = DocDeps(\n", " chunks=chunks,\n", " user=UserProfile(plan=\"Starter\", region=\"US\"),\n", ")\n", "\n", - "personalized = await utils.ask(\n", - " \"What are my rate limits and storage limits?\",\n", - " personalized_deps,\n", - " agent,\n", + "display(\n", + " {\n", + " \"user\": personalized_deps.user,\n", + " \"num_chunks\": len(personalized_deps.chunks),\n", + " \"sample_chunk\": (\n", + " personalized_deps.chunks[0].doc_id,\n", + " personalized_deps.chunks[0].chunk_id,\n", + " ),\n", + " }\n", ")\n", - "\n", - "personalized" + "# The personalized dependency summary is easier to inspect than the full chunk payload." ] }, { - "cell_type": "markdown", - "id": "9111420d", - "metadata": {}, + "cell_type": "code", + "execution_count": 113, + "id": "74d03a82", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AnswerWithSources(answer='- **Rate Limits**: \\n - Team: 120 API requests per minute. \\n - Starter: 30 API requests per minute. \\n\\n- **Storage Limits**: \\n - Starter: 10 GB total storage. \\n - Team: 200 GB total storage.', sources=[SourceRef(doc_id='limits', chunk_id=0, quote='Rate limits\\n- API requests are limited to 120 per minute on Team.\\n- Starter is limited to 30 per minute.\\n\\nStorage\\n- Starter: 10 GB total storage.\\n- Team: 200 GB total storage.')], follow_up_questions=[])" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "# Summary\n", - "\n", - "You built a grounded support assistant using:\n", - "- a synthetic knowledge base\n", - "- deterministic local embeddings for retrieval\n", - "- PydanticAI tools to fetch context\n", - "- structured outputs with citations\n", - "- validators to enforce reliability\n", - "- optional guardrails and personalization\n", + "# Ask a question that depends on the supplied user profile.\n", + "personalized = _run_async(\n", + " utils.ask(\n", + " \"What are my rate limits and storage limits?\",\n", + " personalized_deps,\n", + " agent,\n", + " )\n", + ")\n", "\n", - "This is the core E2E pattern for building production-grade assistants with PydanticAI." + "display(personalized)\n", + "# The final answer can now reflect user-specific context." ] } ], diff --git a/tutorials/tutorial_pydanticAI/pydanticai.example.py b/tutorials/tutorial_pydanticAI/pydanticai.example.py index e76ffc983..d4b1b67e0 100644 --- a/tutorials/tutorial_pydanticAI/pydanticai.example.py +++ b/tutorials/tutorial_pydanticAI/pydanticai.example.py @@ -6,7 +6,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.19.0 +# jupytext_version: 1.19.1 # kernelspec: # display_name: Python 3 (ipykernel) # language: python @@ -17,99 +17,127 @@ # %load_ext autoreload # %autoreload 2 +# System libraries. import logging +# Third party libraries. +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from dotenv import find_dotenv, load_dotenv -import helpers.hnotebook as ut -ut.config_notebook() +# %% +# Import notebook-specific libraries. +import asyncio +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +from IPython.display import Markdown, display +import nest_asyncio +from pydantic import BaseModel, Field +from pydantic_ai import Agent, RunContext + +import helpers.hio as hio -# Initialize logger. -logging.basicConfig(level=logging.INFO) -_LOG = logging.getLogger(__name__) # %% +import logging + +# Local utility. import pydanticai_example_utils as utils +_LOG = logging.getLogger(__name__) +utils.init_logger(_LOG) + +display("Notebook logging initialized.") +# Notebook and utility logging are now configured. + + # %% [markdown] +# # Summary +# +# - This notebook shows how to build a grounded Atlas support assistant with retrieval, structured outputs, validation, guardrails, and personalization +# # # PydanticAI Example Notebook: Atlas Support Assistant (E2E) # -# This notebook builds a small "support assistant" for a synthetic product called **Atlas**. +# - Goal: build a small support assistant for the synthetic product **Atlas** +# - Workflow: +# - Generate a synthetic knowledge base +# - Load and chunk the docs +# - Build a local embedding index +# - Add retrieval as a **PydanticAI** tool +# - Use structured outputs with citations +# - Add validators, guardrails, and personalization # -# We will: -# 1. Generate a synthetic knowledge base (Markdown docs) -# 2. Load + chunk the docs -# 3. Build a simple local embedding index (no external embedding service required) -# 4. Add retrieval as a **PydanticAI tool** -# 5. Use **structured outputs** (Pydantic schema) with **citations** -# 6. Add **validators** to enforce rules like "citations required" -# 7. Add optional **guardrails** and **personalization** +# - Outcome: an end-to-end pattern you can reuse for real retrieval-augmented assistants # -# The result is an end-to-end pattern you can reuse for real RAG assistants. # %% [markdown] # ## Setup # -# This cell initializes the environment and imports all required libraries. -# -# PydanticAI agents need: -# - a model identifier (for example `openai:gpt-4o-mini`) -# - a provider API key (for example `OPENAI_API_KEY`) -# -# Everything else in this notebook is local and self-contained. +# - `PydanticAI` agents need: +# - A model identifier, such as `openai:gpt-4o-mini` +# - A provider API key, such as `OPENAI_API_KEY` +# - Create a ```.env``` file containing these variables to be called in the notebook +# - Everything else in this notebook is local and self-contained # # %% -# !pip install -q pydantic-ai +# Enable nested event loops so async agent calls run inside the notebook. +nest_asyncio.apply() -# %% -import os -import functools -from pathlib import Path -from dataclasses import dataclass -from typing import Optional +display("Nested event loop support enabled.") +# Async notebook execution is now configured. -import nest_asyncio -nest_asyncio.apply() +# %% +# Run notebook coroutines through the current event loop so the paired Python file compiles. +def _run_async(awaitable): + return asyncio.get_event_loop().run_until_complete(awaitable) -from pydantic import BaseModel, Field -from pydantic_ai import Agent -MODEL_ID = os.getenv("PYDANTIC_AI_MODEL", "openai:gpt-4o-mini") -print("MODEL_ID:", MODEL_ID) -print("OPENAI_API_KEY set:", bool(os.getenv("OPENAI_API_KEY"))) +display(_run_async) +# Notebook async calls can now run without top-level await statements. +# %% +# Load environment variables from a local dotenv file if one exists. +env_path = find_dotenv(usecwd=True) +load_dotenv(env_path, override=True) +_LOG.info("dotenv path: %s", env_path or "") +env_path or "" +# Environment variables are available to the model configuration cells. + # %% [markdown] # ## Data and Scenario # -# We build a tiny product docs corpus to keep the tutorial self-contained. -# -# We will build a tiny documentation set for an imaginary product called **Atlas**. +# - This notebook uses a small product-docs corpus to stay self-contained +# - The corpus describes an imaginary product called **Atlas** # # %% [markdown] # ### What this cell does # -# - Creates a local folder `example_dataset/` and writes a small set of **synthetic product/support documents** as Markdown files. -# - Each file represents a support knowledge-base article (billing, troubleshooting, security, limits, etc.). -# - The dataset is intentionally small but diverse so retrieval can return the *right* document depending on the question. -# -# ### Importance -# -# PydanticAI becomes most useful when the agent is grounded in external context (RAG-style). -# These documents act as that context. In the next steps, we will: +# - Creates a local folder `example_dataset/` and writes a small set of synthetic support documents +# - Uses one file per support knowledge-base article +# - Keeps the dataset small so retrieval behavior stays easy to inspect # -# 1. Load these Markdown files into memory -# 2. Retrieve relevant chunks for a user query -# 3. Use a PydanticAI agent + tools to answer using retrieved text -# 4. Return a structured output with citations # %% +# Create the local directory that stores the synthetic support documents. DOCS_DIR = Path("example_dataset/") -DOCS_DIR.mkdir(parents=True, exist_ok=True) +hio.create_dir(str(DOCS_DIR), incremental=True) +display(DOCS_DIR) +# The example dataset directory is now available. + + +# %% +# Define the synthetic Atlas support documents used throughout the notebook. DOCS = { "overview.md": """ # Atlas Overview @@ -179,40 +207,42 @@ """, } +display(sorted(DOCS)) +# The notebook now has a compact synthetic document corpus. + + +# %% +# Materialize the synthetic documents on disk if they do not already exist. for name, text in DOCS.items(): path = DOCS_DIR / name if not path.exists(): - path.write_text(text.strip() + "\n") + path.write_text(text.strip() + "\n", encoding="utf-8") -print("Docs directory:", DOCS_DIR) -print("Files:", [p.name for p in DOCS_DIR.glob("*.md")]) +display(sorted(p.name for p in DOCS_DIR.glob("*.md"))) +# The synthetic knowledge-base files are now stored on disk. # %% [markdown] -# We load all Markdown files into a standard in-memory format: +# - We load Markdown files into a standard in-memory format: # -# - `doc_id`: stable identifier for citations -# - `title`: human-readable name -# - `text`: document content +# - `doc_id`: stable identifier for citations +# - `title`: human-readable name +# - `text`: document content +# +# - A consistent document schema makes it easier to build retrieval tools and return structured citations # -# A consistent document schema makes it easy to: -# - pass documents into dependencies (`deps`) -# - build retrieval tools -# - return structured citations in the agent output # %% [markdown] # ## Chunking and Local Embeddings # -# We split each document into chunks and compute a deterministic vector for each chunk. -# -# ### Why this approach -# - It is fully local and reproducible (no external embedding API required) -# - It is good enough to demonstrate retrieval and grounding +# - We split each document into chunks and computes a deterministic vector for each chunk +# - This helps: +# - Ensure it is fully local and reproducible +# - Ensure it is good enough to demonstrate retrieval and grounding # -# ### Importance -# PydanticAI agents become far more reliable when they can retrieve relevant context via tools instead of guessing. # %% +# Define the chunk schema used for retrieval and citations. @dataclass class DocChunk: doc_id: str @@ -221,20 +251,33 @@ class DocChunk: vector: list[float] +display(DocChunk) +# The notebook now has a typed schema for retrieved chunks. + + +# %% +# Load the markdown documents and convert them into embedded chunks. docs = utils.load_docs(DOCS_DIR) chunks = utils.chunk_docs(docs, DocChunk, max_chars=700) -print("Chunks:", len(chunks)) -print("Example:", chunks[0].doc_id, chunks[0].chunk_id) + +display( + { + "num_docs": len(docs), + "num_chunks": len(chunks), + "first_chunk": (chunks[0].doc_id, chunks[0].chunk_id), + } +) +# The raw documents are now available as retrieval-ready chunks. # %% [markdown] # ## Build a lightweight search index / Retrieval # -# We search the chunk index for the most relevant pieces of text for a query. +# - We then searche the chunk index for the most relevant pieces of text for a query # - # %% +# Define the schema for previewing ranked retrieval matches. class DocMatch(BaseModel): doc_id: str chunk_id: int @@ -242,43 +285,36 @@ class DocMatch(BaseModel): text: str +display(DocMatch.model_json_schema()) +# Retrieval results will now have a structured schema. + + +# %% +# Search the chunk index with a realistic support question. preview = utils.search_chunks( - chunks, "How do I download invoices?", DocMatch, top_k=3 + chunks, + "How do I download invoices?", + DocMatch, + top_k=3, ) -print("Preview matches:") -for m in preview: - print(m.doc_id, "chunk", m.chunk_id, "score=", round(m.score, 4)) +display(pd.DataFrame([match.model_dump() for match in preview])) +# The preview shows which document chunks rank highest for the query. -# %% [markdown] -# ### Importance -# -# - We represent each document chunk as a vector and compute similarity with a query vector using dot product. -# - `search_chunks(...)` ranks chunks by similarity and returns the top matches. -# # %% [markdown] # ## Dependencies and Output Schema # -# ### Dependencies (`DocDeps`) -# Dependencies are runtime context passed into the agent at execution time. Here we store: -# - the chunk index -# - an optional user profile (for personalization) +# - Dependencies are runtime context passed into the agent at execution time +# - The output schema keeps answers and citations in a predictable format # -# ### Output schema (`AnswerWithSources`) -# The agent output is forced into a structured format: -# - `answer`: the response text -# - `sources`: citations with `doc_id`, `chunk_id`, and a short quote -# - `follow_up_questions`: optional list to support guardrails -# -# -# Structured outputs eliminate brittle parsing and make results usable in real applications. # %% +# Define the dependency and output schemas used by the agent. @dataclass class DocDeps: chunks: list[DocChunk] - user: Optional["UserProfile"] = None # optional personalization + user: Optional["UserProfile"] = None # Optional personalization. class SourceRef(BaseModel): @@ -292,7 +328,7 @@ class AnswerWithSources(BaseModel): sources: list[SourceRef] = Field(default_factory=list) follow_up_questions: list[str] = Field( default_factory=list - ) # enables guardrails section later + ) # Optional prompts for follow-up guidance. @dataclass @@ -301,31 +337,41 @@ class UserProfile: region: str +display(AnswerWithSources.model_json_schema()) +# The agent interface is now defined with structured dependencies and output. + + # %% [markdown] # ## Retrieval Tool # -# We wrap retrieval into a tool so the agent can call it during reasoning. -# Tools are the bridge between an LLM and real functionality. Here the tool provides grounded context for RAG-style answers. +# - We then wrap the retrieval into a tool so the agent can call it during reasoning +# - Tools connect the model to real functionality +# # %% -search_docs_tool = functools.partial( - utils.search_docs, - doc_match_cls=DocMatch, -) +# Bind the retrieval helper into a tool the agent can invoke. +def search_docs_tool( + ctx: RunContext[DocDeps], query: str, top_k: int = 3 +): + return utils.search_docs( + ctx, + query, + top_k=top_k, + doc_match_cls=DocMatch, + ) + +display(search_docs_tool) +# The retrieval function is now packaged as a callable tool. # %% [markdown] # ## Agent Configuration and Validation # -# This agent has: -# - tools: retrieval -# - deps: chunk store and optional user profile -# - structured output: answer plus citations -# - validator: enforces citation rules and triggers retry +# - This agent combines retrieval tools, structured outputs, and a validator that enforces citation rules # -# The schema ensures output structure, and the validator ensures output quality. Together they turn a chatty model into a reliable system component. # %% +# Configure the Atlas support agent with retrieval and structured output. agent = Agent( MODEL_ID, deps_type=DocDeps, @@ -334,122 +380,175 @@ class UserProfile: instructions=( "You are Atlas Support. " "Use the `search_docs` tool to find relevant text. " - "Answer briefly. If you use document info, include 1-3 sources with doc_id, chunk_id, and short quotes." + "Answer briefly. If you use document info, include 1-3 sources with " + "doc_id, chunk_id, and short quotes." ), ) agent.output_validator(utils.enforce_sources) +display(agent) +# The support agent is now ready to answer grounded questions. + # %% [markdown] -# ## End-to-End Query +# - The validator runs after the model produces a schema-valid `AnswerWithSources` object # -# We run the agent asynchronously using `await` (notebook-safe). +# - The schema checks structure +# - The validator checks reliability rules such as source coverage # -# ### What happened -# - The agent can call `search_docs` to retrieve relevant text -# - The model generates a structured response -# - The validator ensures citations exist if docs were referenced + +# %% [markdown] +# ## End-to-End Query +# +# - Here we run the agent asynchronously +# - Key pattern: +# - Retrieval grounding +# - Structured outputs +# - Reliability checks # -# This is the full pattern: RAG grounding plus structured outputs plus reliability checks. # %% +# Ask an end-to-end support question using the retrieval-augmented agent. deps = DocDeps(chunks=chunks) -out = await utils.ask("How do I download invoices?", deps, agent) -out +out = _run_async(utils.ask("How do I download invoices?", deps, agent)) + +display(out) +# The agent returned a structured answer object. + # %% -print("Answer:\n", out.answer) -print("\nSources:") -for s in out.sources: - print( - f"- {s.doc_id} (chunk {s.chunk_id}): {s.quote[:120].replace('\\n', ' ')}" - ) -if out.follow_up_questions: - print("\nFollow-ups:") - for q in out.follow_up_questions: - print("-", q) +# Render the answer and citations in a notebook-friendly format. +source_lines = [ + f"- `{source.doc_id}` chunk {source.chunk_id}: " + f"{source.quote[:120].replace(chr(10), ' ')}" + for source in out.sources +] +follow_up_lines = [f"- {question}" for question in out.follow_up_questions] +answer_sections = [ + "### Answer", + out.answer, + "", + "### Sources", + *source_lines, +] +if follow_up_lines: + answer_sections.extend(["", "### Follow-up questions", *follow_up_lines]) + +display(Markdown("\n".join(answer_sections))) +# The notebook now displays the answer alongside its citations. + # %% [markdown] # ## Consuming Structured Output # -# We print the answer and citations from the structured result object. Downstream systems can store citations, audit answers, and render sources cleanly without parsing raw text. +# - Structured results help downstream systems store citations and audit answers without parsing raw text +# # %% -try: - utils.enforce_sources( - AnswerWithSources(answer="According to the policy...", sources=[]) - ) -except Exception as e: - print("Validator failure example:", e) +# Build an intentionally invalid answer object for validator inspection. +invalid_answer = AnswerWithSources( + answer="According to the policy...", + sources=[], +) + +display(invalid_answer) +# This object is missing sources even though the answer claims to reference policy text. + + +# %% +# Run the validator to show how it rejects unsupported document-backed claims. +utils.enforce_sources(invalid_answer) + # %% [markdown] -# ### What happened (and why PydanticAI helps) +# ### What happened +# +# - The validator raises `ModelRetry` when an answer cites documentation without including sources # -# This shows the validator catching an invalid output. -# In a real run, `ModelRetry` tells PydanticAI to retry until the output meets the citation rules. # %% [markdown] # ## Streaming Output # -# Streaming returns tokens progressively, which improves perceived latency in chat interfaces. +# - Streaming returns tokens progressively +# - Progressive output improves perceived latency in chat interfaces # -# Streaming is useful for UI experiences and interactive assistants, especially when responses are longer. # %% +# Create a small streaming agent for a short demonstration. stream_agent = Agent( - MODEL_ID, instructions="Write one short paragraph about unit tests." + MODEL_ID, + instructions="Write one short paragraph about unit tests.", ) -await utils.stream_demo(stream_agent) + +display(stream_agent) +# The streaming demonstration agent is now configured. + + +# %% +# Stream a short response into the notebook output area. +_run_async(utils.stream_demo(stream_agent)) + # %% [markdown] # ## Conversation memory (multi-turn) # -# Reuse message history to keep context across turns. +# - Reuse message history to keep context across turns # # %% +# Ask an initial question and validate the grounded response. deps = DocDeps(chunks=chunks) -first = await agent.run("Where do I enable 2FA?", deps=deps) +first = _run_async(agent.run("Where do I enable 2FA?", deps=deps)) utils.enforce_sources(first.output) -follow_up = await agent.run( - "Does that work on the Starter plan?", - deps=deps, - message_history=first.new_messages(), + +display(first.output) +# The first turn establishes grounded context for the next question. + + +# %% +# Reuse the first turn's message history in a follow-up question. +follow_up = _run_async( + agent.run( + "Does that work on the Starter plan?", + deps=deps, + message_history=first.new_messages(), + ) ) utils.enforce_sources(follow_up.output) -print(follow_up.output) + +display(follow_up.output) +# The follow-up answer reuses prior context through message history. # %% [markdown] # ## Guardrails (lightweight) # -# Reject out-of-scope questions without calling the model. +# - Reject out-of-scope questions without calling the model # # %% -guarded = await utils.run_guarded( - "Write me a poem about the ocean.", - DocDeps(chunks=chunks), - agent, - AnswerWithSources, +# Run a guardrail check against an out-of-scope prompt. +guarded = _run_async( + utils.run_guarded( + "Write me a poem about the ocean.", + DocDeps(chunks=chunks), + agent, + AnswerWithSources, + ) ) -print(guarded) + +display(guarded) +# The guardrail returns a bounded response without invoking the main workflow. # %% [markdown] # ## Dynamic updates # -# Add new docs, rebuild the index, and query again. +# - Add new docs, rebuild the index, and query again # # %% -from pathlib import Path - - -# %% -from pathlib import Path - -# 1) Add the new doc +# Add a new support document to the local knowledge base. new_doc = DOCS_DIR / "integrations.md" new_doc.write_text( """ @@ -462,51 +561,66 @@ class UserProfile: encoding="utf-8", ) -# 2) Reload docs in the expected dict format -docs = utils.load_docs(DOCS_DIR) # must return list[dict] with doc_id/title/text +display(new_doc) +# The knowledge base now includes an integrations document. + + +# %% +# Reload the documents and rebuild the retrieval chunks. +docs = utils.load_docs(DOCS_DIR) chunks = utils.chunk_docs(docs, DocChunk, max_chars=700) -# 3) Run the agent (notebook-safe) -deps = DocDeps(chunks=chunks) +display({"num_docs": len(docs), "num_chunks": len(chunks)}) +# The retrieval index now includes the newly added document. -res = await agent.run("Do you support S3?", deps=deps) + +# %% +# Query the updated knowledge base about integrations support. +deps = DocDeps(chunks=chunks) +res = _run_async(agent.run("Do you support S3?", deps=deps)) out = res.output -print("Answer:\n", out.answer) -print("\nSources:") -for s in out.sources: - print( - f"- {s.doc_id} (chunk {s.chunk_id}): {s.quote[:120].replace('\\n', ' ')}" - ) +display(out) +# The updated index returns a grounded answer about S3 support. + # %% [markdown] # ## Personalization via Dependencies # -# We pass a `UserProfile` through dependencies so the agent can tailor answers. Dependencies are the clean way to inject user context, tenant context, and configuration into tools and agent behavior without global state or prompt hacks. +# - Here we pass a `UserProfile` through dependencies so the agent can tailor answers +# - Dependencies are a clean way to inject user context, tenant context, and configuration into tools +# # %% +# Create personalized dependencies for a Starter-plan user. personalized_deps = DocDeps( chunks=chunks, user=UserProfile(plan="Starter", region="US"), ) -personalized = await utils.ask( - "What are my rate limits and storage limits?", - personalized_deps, - agent, +display( + { + "user": personalized_deps.user, + "num_chunks": len(personalized_deps.chunks), + "sample_chunk": ( + personalized_deps.chunks[0].doc_id, + personalized_deps.chunks[0].chunk_id, + ), + } ) +# The personalized dependency summary is easier to inspect than the full chunk payload. -personalized -# %% [markdown] -# # Summary -# -# You built a grounded support assistant using: -# - a synthetic knowledge base -# - deterministic local embeddings for retrieval -# - PydanticAI tools to fetch context -# - structured outputs with citations -# - validators to enforce reliability -# - optional guardrails and personalization -# -# This is the core E2E pattern for building production-grade assistants with PydanticAI. +# %% +# Ask a question that depends on the supplied user profile. +personalized = _run_async( + utils.ask( + "What are my rate limits and storage limits?", + personalized_deps, + agent, + ) +) + +display(personalized) +# The final answer can now reflect user-specific context. + diff --git a/tutorials/tutorial_pydanticAI/pydanticai_API_utils.py b/tutorials/tutorial_pydanticAI/pydanticai_API_utils.py index f91e9c456..349049be7 100644 --- a/tutorials/tutorial_pydanticAI/pydanticai_API_utils.py +++ b/tutorials/tutorial_pydanticAI/pydanticai_API_utils.py @@ -1,14 +1,53 @@ -"""Utility functions for tutorials/tutorial_pydanticAI/pydanticai.API notebook.""" +""" +Utility functions for tutorials/tutorial_pydanticAI/pydanticai.API notebook. +Import as: + +import tutorials.tutorial_pydanticAI.pydanticai_API_utils as ttppaput +""" + +import importlib +import importlib.util +import inspect +import logging +import os +from pathlib import Path from typing import Any from pydantic_ai import ModelRetry, RunContext +import helpers.hdbg as hdbg +import helpers.hnotebook as hnotebo + +_LOG = logging.getLogger(__name__) +_DOCUMENTS_CACHE: dict[str, str] | None = None -# ######################################################################### + +# ############################################################################# # Code for setup and masking. -# ######################################################################### +# ############################################################################# +def init_logger(notebook_log: logging.Logger) -> None: + """ + Initialize notebook and utility logging. + + :param notebook_log: logger from the paired notebook + """ + global _LOG + hnotebo.config_notebook() + hdbg.init_logger(verbosity=logging.INFO, use_exec_path=False) + hnotebo.set_logger_to_print(notebook_log) + configured_log = _LOG + hnotebo.set_logger_to_print(configured_log) + _LOG = configured_log + + def _mask(value: str | None) -> str: + """ + Mask a secret value for notebook display. + + :param value: value to mask + :return: masked value + """ if not value: return "" if len(value) <= 6: @@ -16,26 +55,124 @@ def _mask(value: str | None) -> str: return f"{value[:3]}...{value[-2:]}" -# ######################################################################### +def log_environment(env_path: str, model_id: str) -> None: + """ + Log notebook environment settings. + + :param env_path: dotenv file path + :param model_id: configured model identifier + """ + _LOG.info("dotenv path: %s", env_path or "") + _LOG.info("PYDANTIC_AI_MODEL: %s", model_id) + _LOG.info("OPENAI_API_KEY: %s", _mask(os.getenv("OPENAI_API_KEY"))) + + +# ############################################################################# # Code for tools and dependencies. -# ######################################################################### +# ############################################################################# def get_weather(city: str) -> str: - return f"The weather in {city} is sunny." + """ + Get deterministic demo weather for a city. + + :param city: city name + :return: weather response + """ + weather = f"The weather in {city} is sunny." + return weather def company_name(ctx: RunContext[Any]) -> str: - return ctx.deps.company + """ + Get the configured company from an agent run context. + :param ctx: PydanticAI run context + :return: configured company name + """ + company = ctx.deps.company + return company -# ######################################################################### + +# ############################################################################# # Code for async execution and validation demos. -# ######################################################################### -async def run_agent(agent: Any) -> Any: - result = await agent.run("Tell me about Tokyo") - return result.output +# ############################################################################# +def load_example_documents() -> dict[str, str]: + """ + Load tutorial documents used by validator and retrieval demos. + + :return: mapping from document id to document text + """ + global _DOCUMENTS_CACHE + if _DOCUMENTS_CACHE is not None: + return _DOCUMENTS_CACHE + dataset_dir = Path(__file__).resolve().parent / "example_dataset" + documents = {} + for path in sorted(dataset_dir.glob("*.md")): + documents[path.stem] = path.read_text() + _DOCUMENTS_CACHE = documents + return documents + + +def get_available_document_ids() -> list[str]: + """ + Get sorted document ids from the example dataset. + + :return: sorted list of document ids + """ + document_ids = sorted(load_example_documents()) + return document_ids + + +def search_documents(query: str, max_results: int = 3) -> str: + """ + Search local tutorial documents and return snippets for citation. + + :param query: search query + :param max_results: maximum number of snippets to return + :return: formatted snippets with doc ids and quotes + """ + documents = load_example_documents() + query_terms = [term for term in query.lower().split() if len(term) > 2] + candidates = [] + for doc_id, content in documents.items(): + for line in content.splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + line_l = line.lower() + score = sum(1 for term in query_terms if term in line_l) + if score == 0 and query_terms: + continue + candidates.append((score, doc_id, line)) + candidates.sort(key=lambda item: (-item[0], item[1], item[2])) + if not candidates: + return "No matching snippets found." + snippets = [] + for _, doc_id, line in candidates[:max_results]: + snippets.append(f"doc_id={doc_id} | quote={line}") + snippets_out = "\n".join(snippets) + return snippets_out + + +async def run_agent(agent: Any, *, prompt: str = "Tell me about Tokyo") -> Any: + """ + Run an agent asynchronously. + + :param agent: PydanticAI agent + :param prompt: prompt to send to the agent + :return: agent output + """ + result = await agent.run(prompt) + output = result.output + return output def validate_sources(result: Any) -> Any: + """ + Validate answer source references. + + :param result: model output to validate + :return: validated model output + """ answer_l = result.answer.lower() mentions_docs = any( token in answer_l for token in ["doc", "document", "according", "source"] @@ -44,18 +181,158 @@ def validate_sources(result: Any) -> Any: raise ModelRetry("Answer references documents but sources are empty.") if len(result.sources) > 3: raise ModelRetry("Too many sources. Maximum allowed is 3.") - seen = set() - for s in result.sources: - key = (s.doc_id, s.quote) + seen: set[tuple[str, str]] = set() + for source in result.sources: + key = (source.doc_id, source.quote) if key in seen: raise ModelRetry("Duplicate sources found.") seen.add(key) return result -async def run_validator_example(validator_agent: Any) -> None: - result = await validator_agent.run( - "Explain something using documents and cite sources." - ) - print("\nValidated output:\n") - print(result.output) +def validate_document_sources(result: Any) -> Any: + """ + Validate sources against local tutorial documents. + + :param result: model output to validate + :return: validated model output + """ + result = validate_sources(result) + documents = load_example_documents() + for source in result.sources: + if source.doc_id not in documents: + raise ModelRetry( + f"Unknown doc_id '{source.doc_id}'. Use ids from example_dataset." + ) + doc_text = " ".join(documents[source.doc_id].lower().split()) + quote_text = " ".join(source.quote.lower().split()) + if quote_text not in doc_text: + raise ModelRetry( + f"Quote not found in cited document '{source.doc_id}'." + ) + return result + + +def build_missing_sources_retry() -> ModelRetry: + """ + Build the retry exception used by the missing-sources demo. + + :return: retry exception + """ + retry = ModelRetry("Answer references documents but sources are empty.") + return retry + + +async def run_validator_example( + validator_agent: Any, + *, + prompt: str = "Use local documents to explain Atlas billing plans and cite sources.", +) -> Any: + """ + Run the result validator example. + + :param validator_agent: configured validator agent + :return: validated output + """ + result = await validator_agent.run(prompt) + output = result.output + return output + + +# ############################################################################# +# Code for advanced API demos. +# ############################################################################# +async def run_streaming_demo(stream_agent: Any) -> Any: + """ + Run a streaming demo and log streamed text. + + :param stream_agent: configured streaming agent + :return: final streaming result or non-streamed result + """ + if not hasattr(stream_agent, "run_stream"): + _LOG.info("Streaming API not available; falling back to run().") + result = await stream_agent.run("What are unit tests?") + return result + async with stream_agent.run_stream("What are unit tests?") as stream: + stream_text = stream.stream_text + parameters = inspect.signature(stream_text).parameters + if "delta" in parameters: + text_stream = stream_text(delta=True) + else: + text_stream = stream_text() + chunks = [] + async for chunk in text_stream: + chunks.append(chunk) + if hasattr(stream, "get_final_result"): + result = await stream.get_final_result() + else: + result = "".join(chunks) + _LOG.info("Streaming output:\n%s", "".join(chunks)) + return result + + +def _get_openai_model_class() -> Any | None: + """ + Get the available explicit OpenAI model class. + + :return: model class, or None if unavailable + """ + if importlib.util.find_spec("pydantic_ai") is None: + return None + if importlib.util.find_spec("pydantic_ai.models.openai") is None: + return None + module = importlib.import_module("pydantic_ai.models.openai") + for class_name in ("OpenAIModel", "OpenAIChatModel"): + if hasattr(module, class_name): + model_class = getattr(module, class_name) + return model_class + return None + + +def build_explicit_openai_model(model_id: str) -> Any | None: + """ + Build an explicit OpenAI model object when the installed API supports it. + + :param model_id: configured model identifier + :return: explicit model object, or None + """ + model_class = _get_openai_model_class() + if model_class is None: + return None + hdbg.dassert_isinstance(model_id, str) + hdbg.dassert_ne(model_id, "", "Model id cannot be empty") + model_name = model_id.removeprefix("openai:") + _LOG.info("Using OpenAI model with model_name='%s'.", model_name) + signature = inspect.signature(model_class) + parameters = signature.parameters + base_kwargs = { + "api_key": os.getenv("OPENAI_API_KEY"), + "base_url": os.getenv("OPENAI_BASE_URL"), + } + args = [] + kwargs = {} + if "model_name" in parameters: + kwargs["model_name"] = model_name + elif "model" in parameters: + kwargs["model"] = model_name + else: + args.append(model_name) + for key, value in base_kwargs.items(): + if key in parameters: + kwargs[key] = value + model = model_class(*args, **kwargs) + return model + + +def get_settings_classes() -> tuple[Any, Any]: + """ + Get ModelSettings and UsageLimits classes for the installed version. + + :return: ModelSettings and UsageLimits classes + """ + module = importlib.import_module("pydantic_ai") + if hasattr(module, "ModelSettings") and hasattr(module, "UsageLimits"): + return module.ModelSettings, module.UsageLimits + models_module = importlib.import_module("pydantic_ai.models") + usage_module = importlib.import_module("pydantic_ai.usage") + return models_module.ModelSettings, usage_module.UsageLimits diff --git a/tutorials/tutorial_pydanticAI/pydanticai_example_utils.py b/tutorials/tutorial_pydanticAI/pydanticai_example_utils.py index 1a0c3bfc9..f86a010ed 100644 --- a/tutorials/tutorial_pydanticAI/pydanticai_example_utils.py +++ b/tutorials/tutorial_pydanticAI/pydanticai_example_utils.py @@ -1,13 +1,17 @@ """Utility functions for tutorials/tutorial_pydanticAI/pydanticai.example notebook.""" +import logging import hashlib import math import re from pathlib import Path from typing import Any +import helpers.hdbg as hdbg +import helpers.hnotebook as hnotebo from pydantic_ai import ModelRetry +_LOG = logging.getLogger(__name__) # ######################################################################### # Code for chunking and embeddings. @@ -15,6 +19,18 @@ _DIM = 256 +def init_logger(notebook_log: logging.Logger) -> None: + global _LOG + hnotebo.config_notebook() + hdbg.init_logger(verbosity=logging.INFO, use_exec_path=False) + # Init notebook logging. + hnotebo.set_logger_to_print(notebook_log) + # Init utils logging. + configured_log = _LOG or logging.getLogger(__name__) + hnotebo.set_logger_to_print(configured_log) + _LOG = configured_log + + def _stable_index(token: str, dim: int = _DIM) -> int: h = hashlib.md5(token.encode("utf-8")).digest() return int.from_bytes(h[:4], "little") % dim diff --git a/tutorials/tutorial_pydanticAI/requirements.txt b/tutorials/tutorial_pydanticAI/requirements.txt index 0657ec273..757bccfb3 100644 --- a/tutorials/tutorial_pydanticAI/requirements.txt +++ b/tutorials/tutorial_pydanticAI/requirements.txt @@ -2,9 +2,8 @@ ipykernel==6.30.1 jupyterlab==4.4.6 nbclient==0.10.2 nbformat==5.10.4 -python-dotenv==1.1.1 -typing_extensions==4.14.1 -nest_asyncio +python-dotenv +pydantic-ai matplotlib numpy pandas diff --git a/tutorials/tutorial_pydanticAI/test/test_pydanticai_API_utils.py b/tutorials/tutorial_pydanticAI/test/test_pydanticai_API_utils.py new file mode 100644 index 000000000..c90baf2bf --- /dev/null +++ b/tutorials/tutorial_pydanticAI/test/test_pydanticai_API_utils.py @@ -0,0 +1,1265 @@ +""" +Test utility functions for tutorials/tutorial_pydanticAI/pydanticai.API. +""" + +import asyncio +import importlib.util +import logging +import sys +import types +from types import SimpleNamespace +from unittest import mock + +import helpers.hunit_test as hunitest + +if importlib.util.find_spec("pydantic_ai") is None: + + class ModelRetry(Exception): + """ + Minimal stub for pydantic_ai.ModelRetry. + """ + + class RunContext: + """ + Minimal stub for pydantic_ai.RunContext. + """ + + def __class_getitem__(cls, item: object) -> type["RunContext"]: + """ + Support type annotations that use RunContext[Any]. + + :param item: type argument + :return: RunContext class + """ + return cls + + pydantic_ai_stub = types.ModuleType("pydantic_ai") + pydantic_ai_stub.ModelRetry = ModelRetry + pydantic_ai_stub.RunContext = RunContext + sys.modules["pydantic_ai"] = pydantic_ai_stub + +import pydanticai_API_utils as put +from pydantic_ai import ModelRetry + +_LOG = logging.getLogger(__name__) + + +# ############################################################################# +# Test_mask +# ############################################################################# + + +class Test__mask(hunitest.TestCase): + """ + Test secret masking for notebook environment output. + """ + + def helper(self, value: str | None, expected: str) -> None: + """ + Test helper for `_mask()`. + + :param value: value to mask + :param expected: expected masked value + """ + # Run test. + actual = put._mask(value) + # Check outputs. + self.assert_equal(actual, expected) + + def test1(self) -> None: + """ + Test masking a missing value. + """ + # Prepare inputs. + value = None + # Prepare outputs. + expected = "" + # Run test. + self.helper(value, expected) + + def test2(self) -> None: + """ + Test masking an empty value. + """ + # Prepare inputs. + value = "" + # Prepare outputs. + expected = "" + # Run test. + self.helper(value, expected) + + def test3(self) -> None: + """ + Test masking a short value. + """ + # Prepare inputs. + value = "secret" + # Prepare outputs. + expected = "******" + # Run test. + self.helper(value, expected) + + def test4(self) -> None: + """ + Test masking a normal secret value. + """ + # Prepare inputs. + value = "sk-1234567890" + # Prepare outputs. + expected = "sk-...90" + # Run test. + self.helper(value, expected) + + +# ############################################################################# +# Test_init_logger +# ############################################################################# + + +class Test_init_logger(hunitest.TestCase): + """ + Test notebook logger initialization. + """ + + def test1(self) -> None: + """ + Test that notebook logging helpers are configured. + """ + # Prepare inputs. + notebook_log = logging.getLogger("test_notebook") + utils_log = logging.getLogger("test_utils") + # Run test. + with mock.patch.object(put, "_LOG", utils_log), mock.patch.object( + put.hnotebo, "config_notebook" + ) as mock_config, mock.patch.object( + put.hdbg, "init_logger" + ) as mock_init_logger, mock.patch.object( + put.hnotebo, + "set_logger_to_print", + return_value=None, + ) as mock_set_logger: + put.init_logger(notebook_log) + # Check outputs. + mock_config.assert_called_once() + mock_init_logger.assert_called_once_with( + verbosity=logging.INFO, use_exec_path=False + ) + self.assertEqual(mock_set_logger.call_count, 2) + + def test2(self) -> None: + """ + Test that logger configuration uses the notebook logger and module logger. + """ + # Prepare inputs. + notebook_log = logging.getLogger("test_notebook") + utils_log = logging.getLogger("test_utils") + # Run test. + with mock.patch.object(put, "_LOG", utils_log), mock.patch.object( + put.hnotebo, "config_notebook" + ), mock.patch.object( + put.hdbg, "init_logger" + ), mock.patch.object( + put.hnotebo, + "set_logger_to_print", + return_value=None, + ) as mock_set_logger: + put.init_logger(notebook_log) + # Check outputs. + self.assertEqual( + mock_set_logger.call_args_list, + [mock.call(notebook_log), mock.call(utils_log)], + ) + + +# ############################################################################# +# Test_log_environment +# ############################################################################# + + +class Test_log_environment(hunitest.TestCase): + """ + Test environment logging for notebook setup. + """ + + def test1(self) -> None: + """ + Test logging configured environment values. + """ + # Prepare inputs. + env_path = "/tmp/.env" + model_id = "openai:gpt-5-nano" + openai_api_key = "sk-1234567890" + # Prepare outputs. + expected = [ + mock.call("dotenv path: %s", env_path), + mock.call("PYDANTIC_AI_MODEL: %s", model_id), + mock.call("OPENAI_API_KEY: %s", "sk-...90"), + ] + # Run test. + with mock.patch.object(put._LOG, "info") as mock_log, mock.patch.dict( + put.os.environ, {"OPENAI_API_KEY": openai_api_key}, clear=False + ): + put.log_environment(env_path, model_id) + # Check outputs. + self.assertEqual(mock_log.call_args_list, expected) + + def test2(self) -> None: + """ + Test logging missing environment values. + """ + # Prepare inputs. + env_path = "" + model_id = "" + # Prepare outputs. + expected = [ + mock.call("dotenv path: %s", ""), + mock.call("PYDANTIC_AI_MODEL: %s", ""), + mock.call("OPENAI_API_KEY: %s", ""), + ] + # Run test. + with mock.patch.object(put._LOG, "info") as mock_log, mock.patch.dict( + put.os.environ, {}, clear=True + ): + put.log_environment(env_path, model_id) + # Check outputs. + self.assertEqual(mock_log.call_args_list, expected) + + +# ############################################################################# +# Test_get_weather +# ############################################################################# + + +class Test_get_weather(hunitest.TestCase): + """ + Test deterministic weather output. + """ + + def helper(self, city: str, expected: str) -> None: + """ + Test helper for `get_weather()`. + + :param city: city name + :param expected: expected weather response + """ + # Run test. + actual = put.get_weather(city) + # Check outputs. + self.assert_equal(actual, expected) + + def test1(self) -> None: + """ + Test weather output for a normal city. + """ + # Prepare inputs. + city = "Tokyo" + # Prepare outputs. + expected = "The weather in Tokyo is sunny." + # Run test. + self.helper(city, expected) + + def test2(self) -> None: + """ + Test weather output for an empty city. + """ + # Prepare inputs. + city = "" + # Prepare outputs. + expected = "The weather in is sunny." + # Run test. + self.helper(city, expected) + + +# ############################################################################# +# Test_build_missing_sources_retry +# ############################################################################# + + +class Test_build_missing_sources_retry(hunitest.TestCase): + """ + Test construction of the missing-sources retry exception. + """ + + def test1(self) -> None: + """ + Test that the helper builds a ModelRetry instance. + """ + # Prepare outputs. + expected = "Answer references documents but sources are empty." + # Run test. + actual = put.build_missing_sources_retry() + # Check outputs. + self.assertIsInstance(actual, ModelRetry) + self.assert_equal(str(actual), expected) + + +# ############################################################################# +# Test_validate_sources +# ############################################################################# + + +class Test_validate_sources(hunitest.TestCase): + """ + Test answer source validation. + """ + + def helper(self, result: SimpleNamespace, expected: str | SimpleNamespace) -> None: + """ + Test helper for `validate_sources()`. + + :param result: validator input + :param expected: expected output or retry message + """ + # Run test. + if isinstance(expected, str): + with self.assertRaises(ModelRetry) as cm: + put.validate_sources(result) + actual = str(cm.exception) + # Check outputs. + self.assert_equal(actual, expected) + else: + actual = put.validate_sources(result) + # Check outputs. + self.assertEqual(actual, expected) + + def test1(self) -> None: + """ + Test an answer with no document claim and no sources. + """ + # Prepare inputs. + result = self._build_result("This answer is standalone.", []) + # Run test. + self.helper(result, result) + + def test2(self) -> None: + """ + Test an answer with document references and sources. + """ + # Prepare inputs. + sources = [self._build_source("doc1", "quoted text")] + result = self._build_result("According to the document.", sources) + # Run test. + self.helper(result, result) + + def test3(self) -> None: + """ + Test that duplicate sources raise ModelRetry. + """ + # Prepare inputs. + sources = [ + self._build_source("doc1", "quoted text"), + self._build_source("doc1", "quoted text"), + ] + result = self._build_result("Standalone answer.", sources) + # Prepare outputs. + expected = "Duplicate sources found." + # Run test. + self.helper(result, expected) + + def test4(self) -> None: + """ + Test that too many sources raise ModelRetry. + """ + # Prepare inputs. + sources = [ + self._build_source("doc1", "quote1"), + self._build_source("doc2", "quote2"), + self._build_source("doc3", "quote3"), + self._build_source("doc4", "quote4"), + ] + result = self._build_result("Standalone answer.", sources) + # Prepare outputs. + expected = "Too many sources. Maximum allowed is 3." + # Run test. + self.helper(result, expected) + + def test5(self) -> None: + """ + Test that document claims without sources raise ModelRetry. + """ + # Prepare inputs. + result = self._build_result("According to the documents.", []) + # Prepare outputs. + expected = "Answer references documents but sources are empty." + # Run test. + self.helper(result, expected) + + @staticmethod + def _build_result( + answer: str, sources: list[SimpleNamespace] + ) -> SimpleNamespace: + """ + Build a validator input object. + + :param answer: answer text + :param sources: source references + :return: validator input + """ + result = SimpleNamespace(answer=answer, sources=sources) + return result + + @staticmethod + def _build_source(doc_id: str, quote: str) -> SimpleNamespace: + """ + Build a source reference object. + + :param doc_id: document identifier + :param quote: source quote + :return: source reference + """ + source = SimpleNamespace(doc_id=doc_id, quote=quote) + return source + + +# ############################################################################# +# Test_company_name +# ############################################################################# + + +class Test_company_name(hunitest.TestCase): + """ + Test dependency access for the company-name tool. + """ + + def test1(self) -> None: + """ + Test reading the company from a run context. + """ + # Prepare inputs. + ctx = SimpleNamespace(deps=SimpleNamespace(company="OpenAI")) + # Prepare outputs. + expected = "OpenAI" + # Run test. + actual = put.company_name(ctx) + # Check outputs. + self.assert_equal(actual, expected) + + def test2(self) -> None: + """ + Test reading an empty company from a run context. + """ + # Prepare inputs. + ctx = SimpleNamespace(deps=SimpleNamespace(company="")) + # Prepare outputs. + expected = "" + # Run test. + actual = put.company_name(ctx) + # Check outputs. + self.assert_equal(actual, expected) + + +# ############################################################################# +# Test_load_example_documents +# ############################################################################# + + +class Test_load_example_documents(hunitest.TestCase): + """ + Test loading local example documents. + """ + + def test1(self) -> None: + """ + Test that tutorial documents are loaded. + """ + # Prepare inputs. + put._DOCUMENTS_CACHE = None + # Run test. + actual = put.load_example_documents() + # Check outputs. + self.assertIn("billing", actual) + self.assertIn("Starter: $20 per month", actual["billing"]) + + def test2(self) -> None: + """ + Test that the cached documents are reused. + """ + # Prepare inputs. + expected = {"cached": "document"} + put._DOCUMENTS_CACHE = expected + # Run test. + actual = put.load_example_documents() + # Check outputs. + self.assertEqual(actual, expected) + + +# ############################################################################# +# Test_get_available_document_ids +# ############################################################################# + + +class Test_get_available_document_ids(hunitest.TestCase): + """ + Test document-id discovery. + """ + + def test1(self) -> None: + """ + Test that document ids are returned in sorted order. + """ + # Prepare outputs. + expected = sorted(put.load_example_documents()) + # Run test. + actual = put.get_available_document_ids() + # Check outputs. + self.assert_equal(str(actual), str(expected)) + + def test2(self) -> None: + """ + Test that an empty document mapping returns no document ids. + """ + # Prepare outputs. + expected = [] + # Run test. + with mock.patch.object(put, "load_example_documents", return_value={}): + actual = put.get_available_document_ids() + # Check outputs. + self.assert_equal(str(actual), str(expected)) + + +# ############################################################################# +# Test_search_documents +# ############################################################################# + + +class Test_search_documents(hunitest.TestCase): + """ + Test local document search snippets. + """ + + def helper(self, query: str, max_results: int) -> str: + """ + Test helper for `search_documents()`. + + :param query: search query + :param max_results: maximum number of snippets + :return: search output + """ + # Prepare inputs. + put._DOCUMENTS_CACHE = None + # Run test. + actual = put.search_documents(query, max_results=max_results) + return actual + + def test1(self) -> None: + """ + Test a search query with matching snippets. + """ + # Prepare inputs. + query = "billing starter" + max_results = 1 + # Run test. + actual = self.helper(query, max_results) + # Check outputs. + self.assertIn("doc_id=billing", actual) + self.assertIn("Starter", actual) + + def test2(self) -> None: + """ + Test a search query with no matching snippets. + """ + # Prepare inputs. + query = "zzzzzz" + max_results = 3 + # Prepare outputs. + expected = "No matching snippets found." + # Run test. + actual = self.helper(query, max_results) + # Check outputs. + self.assert_equal(actual, expected) + + def test3(self) -> None: + """ + Test that the result count respects the requested limit. + """ + # Prepare inputs. + query = "" + max_results = 2 + # Run test. + actual = self.helper(query, max_results) + # Check outputs. + self.assertEqual(len(actual.splitlines()), 2) + + +# ############################################################################# +# Test_validate_document_sources +# ############################################################################# + + +class Test_validate_document_sources(hunitest.TestCase): + """ + Test source validation against local documents. + """ + + def helper(self, result: SimpleNamespace, expected: str | SimpleNamespace) -> None: + """ + Test helper for `validate_document_sources()`. + + :param result: validator input + :param expected: expected output or retry message + """ + # Run test. + if isinstance(expected, str): + with self.assertRaises(ModelRetry) as cm: + put.validate_document_sources(result) + actual = str(cm.exception) + # Check outputs. + self.assert_equal(actual, expected) + else: + actual = put.validate_document_sources(result) + # Check outputs. + self.assertEqual(actual, expected) + + def test1(self) -> None: + """ + Test a valid source quote. + """ + # Prepare inputs. + sources = [ + self._build_source( + "billing", + "Starter: $20 per month, 5 data sources, email support.", + ) + ] + result = self._build_result("According to the documents.", sources) + # Run test. + self.helper(result, result) + + def test2(self) -> None: + """ + Test that an unknown document id raises ModelRetry. + """ + # Prepare inputs. + sources = [self._build_source("missing", "quoted text")] + result = self._build_result("According to the documents.", sources) + # Prepare outputs. + expected = "Unknown doc_id 'missing'. Use ids from example_dataset." + # Run test. + self.helper(result, expected) + + def test3(self) -> None: + """ + Test that a quote mismatch raises ModelRetry. + """ + # Prepare inputs. + sources = [self._build_source("billing", "not present in billing")] + result = self._build_result("According to the documents.", sources) + # Prepare outputs. + expected = "Quote not found in cited document 'billing'." + # Run test. + self.helper(result, expected) + + @staticmethod + def _build_result( + answer: str, sources: list[SimpleNamespace] + ) -> SimpleNamespace: + """ + Build a validator input object. + + :param answer: answer text + :param sources: source references + :return: validator input + """ + result = SimpleNamespace(answer=answer, sources=sources) + return result + + @staticmethod + def _build_source(doc_id: str, quote: str) -> SimpleNamespace: + """ + Build a source reference object. + + :param doc_id: document identifier + :param quote: source quote + :return: source reference + """ + source = SimpleNamespace(doc_id=doc_id, quote=quote) + return source + + +# ############################################################################# +# Test_run_agent +# ############################################################################# + + +class Test_run_agent(hunitest.TestCase): + """ + Test async agent helper execution. + """ + + class _Agent: + """ + Minimal async agent used by tests. + """ + + async def run(self, prompt: str) -> SimpleNamespace: + """ + Return a fake run result. + + :param prompt: prompt sent to the agent + :return: fake run result + """ + result = SimpleNamespace(output=f"answer: {prompt}") + return result + + def helper(self, prompt: str, expected: str) -> None: + """ + Test helper for `run_agent()`. + + :param prompt: prompt sent to the agent + :param expected: expected output + """ + # Prepare inputs. + agent = self._Agent() + # Run test. + actual = asyncio.run(put.run_agent(agent, prompt=prompt)) + # Check outputs. + self.assert_equal(actual, expected) + + def test1(self) -> None: + """ + Test running an async agent. + """ + # Prepare inputs. + prompt = "hello" + # Prepare outputs. + expected = "answer: hello" + # Run test. + self.helper(prompt, expected) + + def test2(self) -> None: + """ + Test running an async agent with the default prompt. + """ + # Prepare inputs. + prompt = "Tell me about Tokyo" + # Prepare outputs. + expected = "answer: Tell me about Tokyo" + # Run test. + self.helper(prompt, expected) + + +# ############################################################################# +# Test_run_validator_example +# ############################################################################# + + +class Test_run_validator_example(hunitest.TestCase): + """ + Test validator example helper execution. + """ + + class _Agent: + """ + Minimal async validator agent used by tests. + """ + + async def run(self, prompt: str) -> SimpleNamespace: + """ + Return a fake validator run result. + + :param prompt: prompt sent to the agent + :return: fake run result + """ + result = SimpleNamespace(output={"prompt": prompt}) + return result + + def helper(self, prompt: str, expected: dict[str, str]) -> None: + """ + Test helper for `run_validator_example()`. + + :param prompt: prompt sent to the validator agent + :param expected: expected output + """ + # Prepare inputs. + agent = self._Agent() + # Run test. + actual = asyncio.run(put.run_validator_example(agent, prompt=prompt)) + # Check outputs. + self.assert_equal(str(actual), str(expected)) + + def test1(self) -> None: + """ + Test running the validator example helper. + """ + # Prepare inputs. + prompt = "cite docs" + # Prepare outputs. + expected = {"prompt": prompt} + # Run test. + self.helper(prompt, expected) + + def test2(self) -> None: + """ + Test running the validator example helper with the default prompt. + """ + # Prepare inputs. + prompt = "Use local documents to explain Atlas billing plans and cite sources." + # Prepare outputs. + expected = {"prompt": prompt} + # Run test. + self.helper(prompt, expected) + + +# ############################################################################# +# Test_run_streaming_demo +# ############################################################################# + + +class Test_run_streaming_demo(hunitest.TestCase): + """ + Test streaming helper fallback behavior. + """ + + class _Agent: + """ + Minimal agent without streaming support. + """ + + async def run(self, prompt: str) -> SimpleNamespace: + """ + Return a fake fallback run result. + + :param prompt: prompt sent to the agent + :return: fake run result + """ + result = SimpleNamespace(output=f"fallback: {prompt}") + return result + + class _StreamingText: + """ + Minimal async iterator for stream chunks. + """ + + def __init__(self, chunks: list[str]) -> None: + self._chunks = chunks + self._index = 0 + + def __aiter__(self) -> "Test_run_streaming_demo._StreamingText": + return self + + async def __anext__(self) -> str: + if self._index >= len(self._chunks): + raise StopAsyncIteration + value = self._chunks[self._index] + self._index += 1 + return value + + class _Stream: + """ + Minimal async stream context manager. + """ + + def __init__(self, chunks: list[str], result: str) -> None: + self._chunks = chunks + self._result = result + + async def __aenter__(self) -> "Test_run_streaming_demo._Stream": + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + return None + + def stream_text(self) -> "Test_run_streaming_demo._StreamingText": + return Test_run_streaming_demo._StreamingText(self._chunks) + + async def get_final_result(self) -> str: + return self._result + + class _StreamWithDelta: + """ + Minimal async stream context manager with delta support. + """ + + def __init__(self, chunks: list[str], result: str) -> None: + self._chunks = chunks + self._result = result + + async def __aenter__(self) -> "Test_run_streaming_demo._StreamWithDelta": + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + return None + + def stream_text( + self, delta: bool = False + ) -> "Test_run_streaming_demo._StreamingText": + if not delta: + raise AssertionError("Expected delta=True.") + return Test_run_streaming_demo._StreamingText(self._chunks) + + async def get_final_result(self) -> str: + return self._result + + class _StreamWithoutFinalResult: + """ + Minimal async stream context manager without final-result support. + """ + + def __init__(self, chunks: list[str]) -> None: + self._chunks = chunks + + async def __aenter__( + self, + ) -> "Test_run_streaming_demo._StreamWithoutFinalResult": + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + return None + + def stream_text(self) -> "Test_run_streaming_demo._StreamingText": + return Test_run_streaming_demo._StreamingText(self._chunks) + + class _StreamingAgent: + """ + Minimal agent with streaming support. + """ + + def run_stream(self, prompt: str) -> "Test_run_streaming_demo._Stream": + return Test_run_streaming_demo._Stream( + ["unit ", "tests"], "unit tests" + ) + + class _StreamingAgentWithDelta: + """ + Minimal agent with delta-based streaming support. + """ + + def run_stream( + self, prompt: str + ) -> "Test_run_streaming_demo._StreamWithDelta": + return Test_run_streaming_demo._StreamWithDelta( + ["unit ", "tests"], "unit tests" + ) + + class _StreamingAgentWithoutFinalResult: + """ + Minimal agent without final-result streaming support. + """ + + def run_stream( + self, prompt: str + ) -> "Test_run_streaming_demo._StreamWithoutFinalResult": + return Test_run_streaming_demo._StreamWithoutFinalResult( + ["unit ", "tests"] + ) + + def test1(self) -> None: + """ + Test fallback execution when streaming is unavailable. + """ + # Prepare inputs. + agent = self._Agent() + # Run test. + actual = asyncio.run(put.run_streaming_demo(agent)) + # Check outputs. + self.assert_equal(actual.output, "fallback: What are unit tests?") + + def test2(self) -> None: + """ + Test streaming execution when the streaming API is available. + """ + # Prepare inputs. + agent = self._StreamingAgent() + # Run test. + actual = asyncio.run(put.run_streaming_demo(agent)) + # Check outputs. + self.assert_equal(actual, "unit tests") + + def test3(self) -> None: + """ + Test streaming execution when `stream_text(delta=True)` is supported. + """ + # Prepare inputs. + agent = self._StreamingAgentWithDelta() + # Run test. + actual = asyncio.run(put.run_streaming_demo(agent)) + # Check outputs. + self.assert_equal(actual, "unit tests") + + def test4(self) -> None: + """ + Test streaming execution without `get_final_result()`. + """ + # Prepare inputs. + agent = self._StreamingAgentWithoutFinalResult() + # Run test. + actual = asyncio.run(put.run_streaming_demo(agent)) + # Check outputs. + self.assert_equal(actual, "unit tests") + + +# ############################################################################# +# Test_get_openai_model_class +# ############################################################################# + + +class Test__get_openai_model_class(hunitest.TestCase): + """ + Test OpenAI model class discovery. + """ + + def test1(self) -> None: + """ + Test missing OpenAI model module. + """ + # Prepare inputs. + # Run test. + with mock.patch.object( + put.importlib.util, "find_spec", return_value=None + ): + actual = put._get_openai_model_class() + # Check outputs. + self.assertIsNone(actual) + + def test2(self) -> None: + """ + Test discovery of the explicit model class from the OpenAI module. + """ + # Prepare inputs. + openai_module = SimpleNamespace(OpenAIModel=object) + # Run test. + with mock.patch.object( + put.importlib.util, + "find_spec", + side_effect=[object(), object()], + ), mock.patch.object( + put.importlib, "import_module", return_value=openai_module + ): + actual = put._get_openai_model_class() + # Check outputs. + self.assertEqual(actual, object) + + def test3(self) -> None: + """ + Test missing OpenAI submodule. + """ + # Prepare inputs. + # Run test. + with mock.patch.object( + put.importlib.util, + "find_spec", + side_effect=[object(), None], + ): + actual = put._get_openai_model_class() + # Check outputs. + self.assertIsNone(actual) + + def test4(self) -> None: + """ + Test discovery of the chat-model class from the OpenAI module. + """ + # Prepare inputs. + openai_chat_model = object() + openai_module = SimpleNamespace(OpenAIChatModel=openai_chat_model) + # Run test. + with mock.patch.object( + put.importlib.util, + "find_spec", + side_effect=[object(), object()], + ), mock.patch.object( + put.importlib, "import_module", return_value=openai_module + ): + actual = put._get_openai_model_class() + # Check outputs. + self.assertEqual(actual, openai_chat_model) + + +# ############################################################################# +# Test_build_explicit_openai_model +# ############################################################################# + + +class Test_build_explicit_openai_model(hunitest.TestCase): + """ + Test explicit OpenAI model construction. + """ + + class _Model: + """ + Fake explicit OpenAI model class. + """ + + def __init__( + self, model_name: str, api_key: str | None = None, base_url: str | None = None + ) -> None: + self.model_name = model_name + self.api_key = api_key + self.base_url = base_url + + class _ModelWithModelArg: + """ + Fake explicit OpenAI model class with a `model` kwarg. + """ + + def __init__( + self, model: str, api_key: str | None = None, base_url: str | None = None + ) -> None: + self.model = model + self.api_key = api_key + self.base_url = base_url + + class _ModelWithPositionalArg: + """ + Fake explicit OpenAI model class with a positional model arg. + """ + + def __init__(self, model_name: str) -> None: + self.model_name = model_name + + def test1(self) -> None: + """ + Test missing model class fallback. + """ + # Prepare inputs. + model_id = "openai:gpt-5-nano" + # Run test. + with mock.patch.object( + put, "_get_openai_model_class", return_value=None + ): + actual = put.build_explicit_openai_model(model_id) + # Check outputs. + self.assertIsNone(actual) + + def test2(self) -> None: + """ + Test explicit model construction with environment-backed kwargs. + """ + # Prepare inputs. + model_id = "openai:gpt-5-nano" + # Run test. + with mock.patch.object( + put, "_get_openai_model_class", return_value=self._Model + ), mock.patch.dict( + put.os.environ, + {"OPENAI_API_KEY": "token", "OPENAI_BASE_URL": "https://example.com"}, + clear=False, + ): + actual = put.build_explicit_openai_model(model_id) + # Check outputs. + self.assertEqual(actual.model_name, "gpt-5-nano") + self.assertEqual(actual.api_key, "token") + self.assertEqual(actual.base_url, "https://example.com") + + def test3(self) -> None: + """ + Test explicit model construction with a `model` kwarg. + """ + # Prepare inputs. + model_id = "openai:gpt-5-nano" + # Run test. + with mock.patch.object( + put, "_get_openai_model_class", return_value=self._ModelWithModelArg + ), mock.patch.dict(put.os.environ, {}, clear=False): + actual = put.build_explicit_openai_model(model_id) + # Check outputs. + self.assertEqual(actual.model, "gpt-5-nano") + + def test4(self) -> None: + """ + Test explicit model construction with a positional model arg. + """ + # Prepare inputs. + model_id = "openai:gpt-5-nano" + # Run test. + with mock.patch.object( + put, + "_get_openai_model_class", + return_value=self._ModelWithPositionalArg, + ): + actual = put.build_explicit_openai_model(model_id) + # Check outputs. + self.assertEqual(actual.model_name, "gpt-5-nano") + + def test5(self) -> None: + """ + Test that an empty model id raises an assertion. + """ + # Prepare inputs. + model_id = "" + # Run test and check output. + with self.assertRaises(AssertionError): + with mock.patch.object( + put, "_get_openai_model_class", return_value=self._Model + ): + put.build_explicit_openai_model(model_id) + + +# ############################################################################# +# Test_get_settings_classes +# ############################################################################# + + +class Test_get_settings_classes(hunitest.TestCase): + """ + Test settings class discovery. + """ + + class _ModelSettings: + """ + Fake model settings class. + """ + + class _UsageLimits: + """ + Fake usage limits class. + """ + + class _ModelsModule: + """ + Fake models module. + """ + + ModelSettings = object() + + class _UsageModule: + """ + Fake usage module. + """ + + UsageLimits = object() + + def test1(self) -> None: + """ + Test direct class discovery from the pydantic_ai module. + """ + # Prepare inputs. + pydantic_ai_module = sys.modules["pydantic_ai"] + pydantic_ai_module.ModelSettings = self._ModelSettings + pydantic_ai_module.UsageLimits = self._UsageLimits + # Prepare outputs. + expected = (self._ModelSettings, self._UsageLimits) + # Run test. + actual = put.get_settings_classes() + # Check outputs. + self.assert_equal(str(actual), str(expected)) + del pydantic_ai_module.ModelSettings + del pydantic_ai_module.UsageLimits + + def test2(self) -> None: + """ + Test fallback class discovery from submodules. + """ + # Prepare inputs. + pydantic_ai_module = sys.modules["pydantic_ai"] + if hasattr(pydantic_ai_module, "ModelSettings"): + del pydantic_ai_module.ModelSettings + if hasattr(pydantic_ai_module, "UsageLimits"): + del pydantic_ai_module.UsageLimits + # Prepare outputs. + expected = ( + self._ModelsModule.ModelSettings, + self._UsageModule.UsageLimits, + ) + # Run test. + with mock.patch.object( + put.importlib, + "import_module", + side_effect=[ + pydantic_ai_module, + self._ModelsModule, + self._UsageModule, + ], + ): + actual = put.get_settings_classes() + # Check outputs. + self.assert_equal(str(actual), str(expected)) diff --git a/tutorials/tutorial_pydanticAI/test/test_pydanticai_example_utils.py b/tutorials/tutorial_pydanticAI/test/test_pydanticai_example_utils.py new file mode 100644 index 000000000..7d835173d --- /dev/null +++ b/tutorials/tutorial_pydanticAI/test/test_pydanticai_example_utils.py @@ -0,0 +1,723 @@ +""" +Test utility functions for tutorials/tutorial_pydanticAI/pydanticai.example. +""" + +import asyncio +import importlib.util +import logging +import os +import pathlib +import sys +import types +from dataclasses import dataclass +from types import SimpleNamespace +from unittest import mock + +import helpers.hunit_test as hunitest + +if importlib.util.find_spec("pydantic_ai") is None: + + class ModelRetry(Exception): + """ + Minimal stub for pydantic_ai.ModelRetry. + """ + + pydantic_ai_stub = types.ModuleType("pydantic_ai") + pydantic_ai_stub.ModelRetry = ModelRetry + sys.modules["pydantic_ai"] = pydantic_ai_stub + +import pydanticai_example_utils as peu +from pydantic_ai import ModelRetry + +_LOG = logging.getLogger(__name__) + + +@dataclass +class _DocChunk: + """ + Minimal chunk object for tests. + """ + + doc_id: str + chunk_id: int + text: str + vector: list[float] + + +@dataclass +class _DocMatch: + """ + Minimal ranked match object for tests. + """ + + doc_id: str + chunk_id: int + score: float + text: str + + +# ############################################################################# +# Test_init_logger +# ############################################################################# + + +class Test_init_logger(hunitest.TestCase): + """ + Test notebook logger initialization. + """ + + def test1(self) -> None: + """ + Test initialization with an existing module logger. + """ + # Prepare inputs. + notebook_log = logging.getLogger("test_notebook") + module_log = logging.getLogger("test_module") + # Run test. + with mock.patch.object(peu, "_LOG", module_log), mock.patch.object( + peu.hnotebo, "config_notebook" + ) as mock_config, mock.patch.object( + peu.hdbg, "init_logger" + ) as mock_init_logger, mock.patch.object( + peu.hnotebo, "set_logger_to_print" + ) as mock_set_logger: + peu.init_logger(notebook_log) + # Check outputs. + mock_config.assert_called_once() + mock_init_logger.assert_called_once_with( + verbosity=logging.INFO, use_exec_path=False + ) + self.assertEqual( + mock_set_logger.call_args_list, + [mock.call(notebook_log), mock.call(module_log)], + ) + + def test2(self) -> None: + """ + Test initialization when the module logger must be recreated. + """ + # Prepare inputs. + notebook_log = logging.getLogger("test_notebook") + expected = logging.getLogger(peu.__name__) + # Run test. + with mock.patch.object(peu, "_LOG", None), mock.patch.object( + peu.hnotebo, "config_notebook" + ), mock.patch.object( + peu.hdbg, "init_logger" + ), mock.patch.object( + peu.hnotebo, "set_logger_to_print" + ) as mock_set_logger: + peu.init_logger(notebook_log) + # Check outputs. + self.assertEqual( + mock_set_logger.call_args_list, + [mock.call(notebook_log), mock.call(expected)], + ) + self.assertEqual(peu._LOG, expected) + + +# ############################################################################# +# Test__stable_index +# ############################################################################# + + +class Test__stable_index(hunitest.TestCase): + """ + Test deterministic token indexing. + """ + + def helper(self, token: str, dim: int) -> int: + """ + Test helper for `_stable_index()`. + + :param token: input token + :param dim: embedding dimension + :return: stable index + """ + # Run test. + actual = peu._stable_index(token, dim=dim) + # Check outputs. + self.assertEqual(actual < dim, True) + self.assertEqual(actual >= 0, True) + return actual + + def test1(self) -> None: + """ + Test that the same token maps deterministically. + """ + # Prepare inputs. + token = "atlas" + dim = 256 + # Run test. + actual1 = self.helper(token, dim) + actual2 = self.helper(token, dim) + # Check outputs. + self.assertEqual(actual1, actual2) + + def test2(self) -> None: + """ + Test that an empty token still maps inside bounds. + """ + # Prepare inputs. + token = "" + dim = 8 + # Run test. + actual = self.helper(token, dim) + # Check outputs. + self.assertEqual(actual < dim, True) + + +# ############################################################################# +# Test_embed +# ############################################################################# + + +class Test_embed(hunitest.TestCase): + """ + Test deterministic text embeddings. + """ + + def test1(self) -> None: + """ + Test embedding an empty string. + """ + # Prepare inputs. + text = "" + # Run test. + actual = peu.embed(text) + # Check outputs. + self.assertEqual(len(actual), 256) + self.assertEqual(sum(actual), 0.0) + + def test2(self) -> None: + """ + Test embedding normalization and token cleanup. + """ + # Prepare inputs. + text1 = "Atlas billing" + text2 = "atlas BILLING!!" + # Run test. + actual1 = peu.embed(text1) + actual2 = peu.embed(text2) + norm = sum(x * x for x in actual1) + # Check outputs. + self.assert_equal(str(actual1), str(actual2)) + self.assertEqual(round(norm, 6), 1.0) + + +# ############################################################################# +# Test_dot +# ############################################################################# + + +class Test_dot(hunitest.TestCase): + """ + Test vector dot products. + """ + + def helper( + self, left: list[float], right: list[float], expected: float + ) -> None: + """ + Test helper for `dot()`. + + :param left: left vector + :param right: right vector + :param expected: expected dot product + """ + # Run test. + actual = peu.dot(left, right) + # Check outputs. + self.assertEqual(actual, expected) + + def test1(self) -> None: + """ + Test a normal dot product. + """ + # Prepare inputs. + left = [1.0, 2.0, 3.0] + right = [4.0, 5.0, 6.0] + # Prepare outputs. + expected = 32.0 + # Run test. + self.helper(left, right, expected) + + def test2(self) -> None: + """ + Test the dot product of empty vectors. + """ + # Prepare inputs. + left = [] + right = [] + # Prepare outputs. + expected = 0 + # Run test. + self.helper(left, right, expected) + + +# ############################################################################# +# Test_chunk_docs +# ############################################################################# + + +class Test_chunk_docs(hunitest.TestCase): + """ + Test document chunking. + """ + + def test1(self) -> None: + """ + Test chunking a short document into one chunk. + """ + # Prepare inputs. + docs = [{"doc_id": "billing", "text": "invoice details"}] + # Run test. + actual = peu.chunk_docs(docs, _DocChunk, max_chars=100) + # Check outputs. + self.assertEqual(len(actual), 1) + self.assertEqual(actual[0].doc_id, "billing") + self.assertEqual(actual[0].chunk_id, 0) + self.assertEqual(actual[0].text, "invoice details") + + def test2(self) -> None: + """ + Test chunking a document into multiple parts. + """ + # Prepare inputs. + docs = [{"doc_id": "billing", "text": "abcdefgh"}] + # Run test. + actual = peu.chunk_docs(docs, _DocChunk, max_chars=3) + # Check outputs. + expected = ["abc", "def", "gh"] + self.assert_equal(str([chunk.text for chunk in actual]), str(expected)) + + +# ############################################################################# +# Test_search_chunks +# ############################################################################# + + +class Test_search_chunks(hunitest.TestCase): + """ + Test chunk ranking and truncation. + """ + + def test1(self) -> None: + """ + Test ranking chunks by query similarity. + """ + # Prepare inputs. + chunks = [ + _DocChunk("limits", 0, "storage limits", peu.embed("storage limits")), + _DocChunk("billing", 0, "invoice billing", peu.embed("invoice billing")), + ] + # Run test. + actual = peu.search_chunks( + chunks, + "invoice", + _DocMatch, + top_k=2, + ) + # Check outputs. + self.assertEqual(actual[0].doc_id, "billing") + self.assertEqual(len(actual), 2) + + def test2(self) -> None: + """ + Test limiting the number of ranked matches. + """ + # Prepare inputs. + chunks = [ + _DocChunk("a", 0, "billing", peu.embed("billing")), + _DocChunk("b", 0, "invoice", peu.embed("invoice")), + _DocChunk("c", 0, "support", peu.embed("support")), + ] + # Run test. + actual = peu.search_chunks(chunks, "invoice", _DocMatch, top_k=1) + # Check outputs. + self.assertEqual(len(actual), 1) + + +# ############################################################################# +# Test_search_docs +# ############################################################################# + + +class Test_search_docs(hunitest.TestCase): + """ + Test context-aware document search. + """ + + def test1(self) -> None: + """ + Test searching through chunks stored in the run context. + """ + # Prepare inputs. + chunks = [ + _DocChunk("billing", 0, "invoice download", peu.embed("invoice download")), + _DocChunk("security", 0, "enable 2fa", peu.embed("enable 2fa")), + ] + ctx = SimpleNamespace(deps=SimpleNamespace(chunks=chunks)) + # Run test. + actual = peu.search_docs(ctx, "invoice", doc_match_cls=_DocMatch) + # Check outputs. + self.assertEqual(actual[0].doc_id, "billing") + + +# ############################################################################# +# Test_enforce_sources +# ############################################################################# + + +class Test_enforce_sources(hunitest.TestCase): + """ + Test answer source validation. + """ + + def helper(self, result: SimpleNamespace, expected: str | SimpleNamespace) -> None: + """ + Test helper for `enforce_sources()`. + + :param result: validator input + :param expected: expected output or retry message + """ + # Run test. + if isinstance(expected, str): + with self.assertRaises(ModelRetry) as cm: + peu.enforce_sources(result) + actual = str(cm.exception) + # Check outputs. + self.assert_equal(actual, expected) + else: + actual = peu.enforce_sources(result) + # Check outputs. + self.assertEqual(actual, expected) + + def test1(self) -> None: + """ + Test a standalone answer with no sources. + """ + # Prepare inputs. + result = self._build_result("This answer is standalone.", []) + # Run test. + self.helper(result, result) + + def test2(self) -> None: + """ + Test a document-backed answer with valid sources. + """ + # Prepare inputs. + sources = [self._build_source("billing", 0, "download invoices")] + result = self._build_result("According to billing docs.", sources) + # Run test. + self.helper(result, result) + + def test3(self) -> None: + """ + Test that document-backed answers require sources. + """ + # Prepare inputs. + result = self._build_result("According to the document.", []) + # Prepare outputs. + expected = "You referenced docs/policies but did not include sources." + # Run test. + self.helper(result, expected) + + def test4(self) -> None: + """ + Test that too many sources raise a retry. + """ + # Prepare inputs. + sources = [ + self._build_source("doc1", 0, "quote1"), + self._build_source("doc2", 0, "quote2"), + self._build_source("doc3", 0, "quote3"), + self._build_source("doc4", 0, "quote4"), + ] + result = self._build_result("Standalone answer.", sources) + # Prepare outputs. + expected = "Too many sources. Max 3." + # Run test. + self.helper(result, expected) + + def test5(self) -> None: + """ + Test that duplicate sources raise a retry. + """ + # Prepare inputs. + sources = [ + self._build_source("doc1", 0, "quote"), + self._build_source("doc1", 0, "quote"), + ] + result = self._build_result("Standalone answer.", sources) + # Prepare outputs. + expected = "Duplicate sources. Keep sources unique." + # Run test. + self.helper(result, expected) + + @staticmethod + def _build_result( + answer: str, sources: list[SimpleNamespace] + ) -> SimpleNamespace: + """ + Build a validator input object. + + :param answer: answer text + :param sources: source references + :return: validator input + """ + result = SimpleNamespace(answer=answer, sources=sources) + return result + + @staticmethod + def _build_source( + doc_id: str, chunk_id: int, quote: str + ) -> SimpleNamespace: + """ + Build a source reference object. + + :param doc_id: document identifier + :param chunk_id: chunk identifier + :param quote: source quote + :return: source reference + """ + source = SimpleNamespace(doc_id=doc_id, chunk_id=chunk_id, quote=quote) + return source + + +# ############################################################################# +# Test_ask +# ############################################################################# + + +class Test_ask(hunitest.TestCase): + """ + Test async agent wrappers. + """ + + def test1(self) -> None: + """ + Test that `ask()` returns the agent output. + """ + # Prepare inputs. + deps = SimpleNamespace(name="deps") + expected = {"answer": "ok"} + + class _Agent: + async def run(self, query: str, deps: object) -> SimpleNamespace: + return SimpleNamespace(output=expected) + + agent = _Agent() + # Run test. + actual = asyncio.run(peu.ask("question", deps, agent)) + # Check outputs. + self.assertEqual(actual, expected) + + +# ############################################################################# +# Test_stream_demo +# ############################################################################# + + +class Test_stream_demo(hunitest.TestCase): + """ + Test streaming notebook output helpers. + """ + + def test1(self) -> None: + """ + Test fallback streaming for agents without `run_stream`. + """ + # Prepare inputs. + expected = "Unit tests matter." + + class _Agent: + async def run(self, query: str) -> SimpleNamespace: + return SimpleNamespace(output=expected) + + stream_agent = _Agent() + # Run test. + with mock.patch("builtins.print") as mock_print: + asyncio.run(peu.stream_demo(stream_agent)) + # Check outputs. + mock_print.assert_called_once_with(expected) + + def test2(self) -> None: + """ + Test streaming text chunks from `run_stream`. + """ + # Prepare inputs. + chunks = ["Unit ", "tests"] + + class _Stream: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def stream_text(self): + for chunk in chunks: + yield chunk + + class _Agent: + def run_stream(self, query: str) -> _Stream: + return _Stream() + + stream_agent = _Agent() + # Run test. + with mock.patch("builtins.print") as mock_print: + asyncio.run(peu.stream_demo(stream_agent)) + # Check outputs. + expected = [ + mock.call("Unit ", end="", flush=True), + mock.call("tests", end="", flush=True), + mock.call("\n"), + ] + self.assertEqual(mock_print.call_args_list, expected) + + +# ############################################################################# +# Test_in_scope +# ############################################################################# + + +class Test_in_scope(hunitest.TestCase): + """ + Test support-question guardrail classification. + """ + + def helper(self, question: str, expected: bool) -> None: + """ + Test helper for `in_scope()`. + + :param question: user question + :param expected: expected classification + """ + # Run test. + actual = peu.in_scope(question) + # Check outputs. + self.assertEqual(actual, expected) + + def test1(self) -> None: + """ + Test an in-scope billing question. + """ + # Prepare inputs. + question = "How do I download an invoice?" + # Prepare outputs. + expected = True + # Run test. + self.helper(question, expected) + + def test2(self) -> None: + """ + Test an out-of-scope creative question. + """ + # Prepare inputs. + question = "Write me a poem about the ocean." + # Prepare outputs. + expected = False + # Run test. + self.helper(question, expected) + + +# ############################################################################# +# Test_run_guarded +# ############################################################################# + + +class Test_run_guarded(hunitest.TestCase): + """ + Test guarded agent execution. + """ + + def test1(self) -> None: + """ + Test the out-of-scope guardrail response. + """ + # Prepare inputs. + answer_with_sources_cls = SimpleNamespace + deps = SimpleNamespace() + agent = SimpleNamespace() + # Run test. + actual = asyncio.run( + peu.run_guarded( + "Write me a poem about the ocean.", + deps, + agent, + answer_with_sources_cls, + ) + ) + # Check outputs. + self.assertEqual( + actual.answer, + "I can only help with Atlas product documentation and support questions.", + ) + self.assertEqual(len(actual.follow_up_questions), 1) + + def test2(self) -> None: + """ + Test delegating an in-scope question to the agent. + """ + # Prepare inputs. + expected = {"answer": "Atlas support answer"} + + class _Agent: + async def run( + self, question: str, deps: object, message_history: object = None + ) -> SimpleNamespace: + return SimpleNamespace(output=expected) + + agent = _Agent() + # Run test. + actual = asyncio.run( + peu.run_guarded( + "How do I contact Atlas support?", + SimpleNamespace(), + agent, + SimpleNamespace, + ) + ) + # Check outputs. + self.assertEqual(actual, expected) + + +# ############################################################################# +# Test_load_docs +# ############################################################################# + + +class Test_load_docs(hunitest.TestCase): + """ + Test loading markdown documents from disk. + """ + + def test1(self) -> None: + """ + Test loading and sorting markdown files. + """ + # Prepare inputs. + scratch_dir = pathlib.Path(self.get_scratch_space()) + (scratch_dir / "zeta.md").write_text("Zeta text", encoding="utf-8") + (scratch_dir / "alpha.md").write_text("Alpha text", encoding="utf-8") + # Run test. + actual = peu.load_docs(scratch_dir) + # Check outputs. + self.assert_equal( + str([doc["doc_id"] for doc in actual]), str(["alpha", "zeta"]) + ) + self.assertEqual(actual[0]["title"], "Alpha") + self.assertEqual(actual[1]["text"], "Zeta text") + + def test2(self) -> None: + """ + Test loading an empty directory. + """ + # Prepare inputs. + scratch_dir = pathlib.Path(self.get_scratch_space()) / "empty_docs" + os.makedirs(scratch_dir, exist_ok=True) + # Run test. + actual = peu.load_docs(scratch_dir) + # Check outputs. + self.assert_equal(str(actual), str([])) diff --git a/tutorials/tutorial_pydanticAI/utils.sh b/tutorials/tutorial_pydanticAI/utils.sh new file mode 100644 index 000000000..cc0ed8c4a --- /dev/null +++ b/tutorials/tutorial_pydanticAI/utils.sh @@ -0,0 +1,607 @@ +#!/bin/bash +# """ +# Utility functions for Docker container management. +# """ + + +# ############################################################################# +# General utilities +# ############################################################################# + + +run() { + # """ + # Execute a command with echo output. + # + # :param cmd: Command string to execute + # :return: Exit status of the executed command + # """ + cmd="$*" + echo "> $cmd" + eval "$cmd" +} + + +enable_verbose_mode() { + # """ + # Enable shell command tracing (set -x) when VERBOSE is set to 1. + # + # Reads the VERBOSE variable set by parse_docker_jupyter_args. + # Call this after parsing args to activate tracing for the rest of the script. + # """ + if [[ $VERBOSE == 1 ]]; then + set -x + fi +} + + +# ############################################################################# +# Argument parsing +# ############################################################################# + + +_print_default_help() { + # """ + # Print usage information and available default options for docker scripts. + # """ + echo "Usage: $(basename $0) [options]" + echo "" + echo "Options:" + echo " -f Force kill existing container with same name before starting" + echo " -h Print this help message and exit" + echo " -v Enable verbose output (set -x)" +} + + +parse_default_args() { + # """ + # Parse default command-line arguments for docker scripts. + # + # Sets VERBOSE and FORCE variables in the caller's scope. Enables set -x + # when -v is passed. Prints help and exits when -h is passed. + # Updates OPTIND so the caller can shift away processed arguments. + # + # :param @: command-line arguments forwarded from the calling script + # """ + VERBOSE=0 + FORCE=0 + while getopts "fhv" flag; do + case "${flag}" in + f) FORCE=1;; + h) _print_default_help; exit 0;; + v) VERBOSE=1;; + *) _print_default_help; exit 1;; + esac + done + enable_verbose_mode +} + + +_print_docker_jupyter_help() { + # """ + # Print usage information and available options for docker_jupyter.sh. + # """ + echo "Usage: $(basename $0) [options]" + echo "" + echo "Launch Jupyter Lab inside a Docker container." + echo "" + echo "Options:" + echo " -f Force kill existing container with same name before starting" + echo " -h Print this help message and exit" + echo " -p PORT Host port to forward to Jupyter Lab (default: 8888)" + echo " -u Enable vim keybindings in Jupyter Lab" + echo " -v Enable verbose output (set -x)" +} + + +parse_docker_jupyter_args() { + # """ + # Parse command-line arguments for docker_jupyter.sh. + # + # Sets JUPYTER_HOST_PORT, JUPYTER_USE_VIM, TARGET_DIR, VERBOSE, FORCE, and + # OLD_CMD_OPTS in the caller's scope. Enables set -x when -v is passed. + # Prints help and exits when -h is passed. + # + # :param @: command-line arguments forwarded from the calling script + # """ + # Set defaults. + JUPYTER_HOST_PORT=8888 + JUPYTER_USE_VIM=0 + VERBOSE=0 + FORCE=0 + # Save original args to pass through to run_jupyter.sh. + OLD_CMD_OPTS="$*" + # Parse options. + while getopts "fhp:uv" flag; do + case "${flag}" in + f) FORCE=1;; + h) _print_docker_jupyter_help; exit 0;; + p) JUPYTER_HOST_PORT=${OPTARG};; # Port for Jupyter Lab. + u) JUPYTER_USE_VIM=1;; # Enable vim bindings. + v) VERBOSE=1;; # Enable verbose output. + *) _print_docker_jupyter_help; exit 1;; + esac + done + # Enable command tracing if verbose mode is requested. + enable_verbose_mode +} + + +# ############################################################################# +# Docker image management +# ############################################################################# + + +get_docker_vars_script() { + # """ + # Load Docker variables from docker_name.sh script. + # + # :param script_path: Path to the script to determine the Docker configuration directory + # :return: Sources REPO_NAME, IMAGE_NAME, and FULL_IMAGE_NAME variables + # """ + local script_path=$1 + # Find the name of the container. + SCRIPT_DIR=$(dirname $script_path) + DOCKER_NAME="$SCRIPT_DIR/docker_name.sh" + if [[ ! -e $SCRIPT_DIR ]]; then + echo "Can't find $DOCKER_NAME" + exit -1 + fi; + source $DOCKER_NAME +} + + +print_docker_vars() { + # """ + # Print current Docker variables to stdout. + # """ + echo "REPO_NAME=$REPO_NAME" + echo "IMAGE_NAME=$IMAGE_NAME" + echo "FULL_IMAGE_NAME=$FULL_IMAGE_NAME" +} + + +build_container_image() { + # """ + # Build a Docker container image. + # + # Supports both single-architecture and multi-architecture builds. + # Creates temporary build directory, copies files, and builds the image. + # + # :param @: Additional options to pass to docker build/buildx build + # """ + echo "# ${FUNCNAME[0]} ..." + FULL_IMAGE_NAME=$REPO_NAME/$IMAGE_NAME + echo "FULL_IMAGE_NAME=$FULL_IMAGE_NAME" + # Prepare build area. + #tar -czh . | docker build $OPTS -t $IMAGE_NAME - + DIR="../tmp.build" + if [[ -d $DIR ]]; then + rm -rf $DIR + fi; + cp -Lr . $DIR || true + # Build container. + echo "DOCKER_BUILDKIT=$DOCKER_BUILDKIT" + echo "DOCKER_BUILD_MULTI_ARCH=$DOCKER_BUILD_MULTI_ARCH" + if [[ $DOCKER_BUILD_MULTI_ARCH != 1 ]]; then + # Build for a single architecture. + echo "Building for current architecture..." + OPTS="--progress plain $@" + (cd $DIR; docker build $OPTS -t $FULL_IMAGE_NAME . 2>&1 | tee ../docker_build.log; exit ${PIPESTATUS[0]}) + else + # Build for multiple architectures. + echo "Building for multiple architectures..." + OPTS="$@" + export DOCKER_CLI_EXPERIMENTAL=enabled + # Create a new builder. + #docker buildx rm --all-inactive --force + #docker buildx create --name mybuilder + #docker buildx use mybuilder + # Use the default builder. + docker buildx use multiarch + docker buildx inspect --bootstrap + # Note that one needs to push to the repo since otherwise it is not + # possible to keep multiple. + (cd $DIR; docker buildx build --push --platform linux/arm64,linux/amd64 $OPTS --tag $FULL_IMAGE_NAME . 2>&1 | tee ../docker_build.log; exit ${PIPESTATUS[0]}) + # Report the status. + docker buildx imagetools inspect $FULL_IMAGE_NAME + fi; + # Report build version. + if [ -f docker_build.version.log ]; then + rm docker_build.version.log + fi + (cd $DIR; docker run --rm -it -v $(pwd):/data $FULL_IMAGE_NAME bash -c "/data/version.sh") 2>&1 | tee docker_build.version.log + # + docker image ls $REPO_NAME/$IMAGE_NAME + rm -rf $DIR + echo "*****************************" + echo "SUCCESS" + echo "*****************************" +} + + +remove_container_image() { + # """ + # Remove Docker container image(s) matching the current configuration. + # """ + echo "# ${FUNCNAME[0]} ..." + FULL_IMAGE_NAME=$REPO_NAME/$IMAGE_NAME + echo "FULL_IMAGE_NAME=$FULL_IMAGE_NAME" + docker image ls | grep $FULL_IMAGE_NAME + docker image ls | grep $FULL_IMAGE_NAME | awk '{print $1}' | xargs -n 1 -t docker image rm -f + docker image ls + echo "${FUNCNAME[0]} ... done" +} + + +push_container_image() { + # """ + # Push Docker container image to registry. + # + # Authenticates using credentials from ~/.docker/passwd.$REPO_NAME.txt. + # """ + echo "# ${FUNCNAME[0]} ..." + FULL_IMAGE_NAME=$REPO_NAME/$IMAGE_NAME + echo "FULL_IMAGE_NAME=$FULL_IMAGE_NAME" + docker login --username $REPO_NAME --password-stdin <~/.docker/passwd.$REPO_NAME.txt + docker images $FULL_IMAGE_NAME + docker push $FULL_IMAGE_NAME + echo "${FUNCNAME[0]} ... done" +} + + +pull_container_image() { + # """ + # Pull Docker container image from registry. + # """ + echo "# ${FUNCNAME[0]} ..." + FULL_IMAGE_NAME=$REPO_NAME/$IMAGE_NAME + echo "FULL_IMAGE_NAME=$FULL_IMAGE_NAME" + docker pull $FULL_IMAGE_NAME + echo "${FUNCNAME[0]} ... done" +} + + +# ############################################################################# +# Docker container management +# ############################################################################# + + +kill_container() { + # """ + # Kill and remove Docker container(s) matching the current configuration. + # """ + echo "# ${FUNCNAME[0]} ..." + FULL_IMAGE_NAME=$REPO_NAME/$IMAGE_NAME + echo "FULL_IMAGE_NAME=$FULL_IMAGE_NAME" + docker container ls + # + CONTAINER_ID=$(docker container ls -a | grep $FULL_IMAGE_NAME | awk '{print $1}') + echo "CONTAINER_ID=$CONTAINER_ID" + if [[ ! -z $CONTAINER_ID ]]; then + docker container rm -f $CONTAINER_ID + docker container ls + fi; + echo "${FUNCNAME[0]} ... done" +} + + +kill_container_by_name() { + # """ + # Kill and remove a Docker container by its name. + # + # :param container_name: Name of the container to kill + # """ + local container_name=$1 + echo "# ${FUNCNAME[0]}: $container_name" + # Check if container exists (running or stopped). + local container_id=$(docker container ls -a --filter "name=^${container_name}$" --format "{{.ID}}") + if [[ -n $container_id ]]; then + echo "Killing container: $container_name (ID: $container_id)" + docker container rm -f $container_id + else + echo "Container '$container_name' not found" + fi + echo "${FUNCNAME[0]} ... done" +} + + +exec_container() { + # """ + # Execute bash shell in running Docker container. + # + # Opens an interactive bash session in the first container matching the + # current configuration. + # """ + echo "# ${FUNCNAME[0]} ..." + FULL_IMAGE_NAME=$REPO_NAME/$IMAGE_NAME + echo "FULL_IMAGE_NAME=$FULL_IMAGE_NAME" + docker container ls + # + CONTAINER_ID=$(docker container ls -a | grep $FULL_IMAGE_NAME | awk '{print $1}') + echo "CONTAINER_ID=$CONTAINER_ID" + docker exec -it $CONTAINER_ID bash + echo "${FUNCNAME[0]} ... done" +} + + +# ############################################################################# +# Docker common options +# ############################################################################# + + +get_docker_common_options() { + # """ + # Return docker run options common to all container types. + # + # Includes volume mount for the git root, plus environment variables for + # PYTHONPATH and host OS name. + # + # :return: docker run options string with volume mounts and env vars + # """ + echo "-v $GIT_ROOT:/git_root \ + -e PYTHONPATH=/git_root:/git_root/helpers_root:/git_root/msml610/tutorials \ + -e CSFY_GIT_ROOT_PATH=/git_root \ + -e CSFY_HOST_OS_NAME=$(uname -s) \ + -e CSFY_HOST_NAME=$(uname -n)" +} + + +# ############################################################################# +# Docker bash +# ############################################################################# + + +get_docker_bash_command() { + # """ + # Return the base docker run command for an interactive bash shell. + # + # :return: docker run command string with --rm and -ti flags + # """ + if [ -t 0 ]; then + echo "docker run --rm -ti" + else + echo "docker run --rm -i" + fi +} + + +get_docker_bash_options() { + # """ + # Return docker run options for a Docker container. + # + # :param container_name: Name for the Docker container + # :param port: Port number to forward (optional, skipped if empty) + # :param extra_opts: Additional docker run options (optional) + # :return: docker run options string with name, volume mounts, and env vars + # """ + local container_name=$1 + local port=$2 + local extra_opts=$3 + local port_opt="" + if [[ -n $port ]]; then + port_opt="-p $port:$port" + fi + echo "--name $container_name \ + $port_opt \ + $extra_opts \ + $(get_docker_common_options)" +} + + +# ############################################################################# +# Docker cmd +# ############################################################################# + + +get_docker_cmd_command() { + # """ + # Return the base docker run command for executing a non-interactive command. + # + # :return: docker run command string with --rm and -i flags + # """ + echo "docker run --rm -i" +} + + +# ############################################################################# +# Docker Jupyter +# ############################################################################# + + +get_docker_jupyter_command() { + # """ + # Return the base docker run command for running Jupyter Lab interactively. + # + # :return: docker run command string with --rm and -ti flags (if TTY available) + # """ + local docker_cmd="docker run --rm" + # Add interactive and TTY flags only if stdin is a TTY. + if [[ -t 0 ]]; then + docker_cmd="$docker_cmd -ti" + fi + echo "$docker_cmd" +} + + +get_docker_jupyter_options() { + # """ + # Return docker run options for a Jupyter Lab container. + # + # :param container_name: Name for the Docker container + # :param host_port: Host port to forward to container port 8888 + # :param jupyter_use_vim: 0 or 1 to enable vim bindings + # :return: docker run options string + # """ + local container_name=$1 + local host_port=$2 + local jupyter_use_vim=$3 + # Run as the current user when user is saggese. + if [[ "$(whoami)" == "saggese" ]]; then + echo "Overwriting jupyter_use_vim since user='saggese'" >&2 + jupyter_use_vim=1 + fi + echo "--name $container_name \ + -p $host_port:8888 \ + $(get_docker_common_options) \ + -e JUPYTER_USE_VIM=$jupyter_use_vim" +} + + +configure_jupyter_vim_keybindings() { + # """ + # Configure JupyterLab vim keybindings based on JUPYTER_USE_VIM env var. + # + # Reads JUPYTER_USE_VIM; if 1, verifies jupyterlab_vim is installed and + # writes enabled settings; otherwise writes disabled settings. + # """ + mkdir -p ~/.jupyter/lab/user-settings/@axlair/jupyterlab_vim + if [[ $JUPYTER_USE_VIM == 1 ]]; then + # Check that jupyterlab_vim is installed before trying to enable it. + if ! pip show jupyterlab_vim > /dev/null 2>&1; then + echo "ERROR: jupyterlab_vim is not installed but vim bindings were requested." + echo "Install it with: pip install jupyterlab_vim" + exit 1 + fi + echo "Enabling vim." + cat < ~/.jupyter/lab/user-settings/\@axlair/jupyterlab_vim/plugin.jupyterlab-settings +{ + "enabled": true, + "enabledInEditors": true, + "extraKeybindings": [], + "autosaveInterval": 6 +} +EOF + else + echo "Disabling vim." + cat < ~/.jupyter/lab/user-settings/\@axlair/jupyterlab_vim/plugin.jupyterlab-settings +{ + "enabled": false, + "enabledInEditors": false, + "extraKeybindings": [], + "autosaveInterval": 6 +} +EOF + fi; +} + + +configure_jupyter_notifications() { + # """ + # Disable JupyterLab news fetching and update checks. + # """ + mkdir -p ~/.jupyter/lab/user-settings/@jupyterlab/apputils-extension + cat < ~/.jupyter/lab/user-settings/\@jupyterlab/apputils-extension/notification.jupyterlab-settings +{ + // Notifications + // @jupyterlab/apputils-extension:notification + // Notifications settings. + + // Fetch official Jupyter news + // Whether to fetch news from the Jupyter news feed. If Always (`true`), it will make a request to a website. + "fetchNews": "false", + "checkForUpdates": false +} +EOF +} + + +configure_jupyter_autosave() { + # """ + # Configure JupyterLab global autosave interval to 6 seconds. + # """ + mkdir -p ~/.jupyter/lab/user-settings/@jupyterlab/docmanager-extension + cat < ~/.jupyter/lab/user-settings/\@jupyterlab/docmanager-extension/plugin.jupyterlab-settings +{ + "autosaveInterval": 6 +} +EOF +} + + +check_jupytext_installed() { + # """ + # Verify that jupytext is installed before starting Jupyter Lab. + # + # Jupytext is required for pair notebook/Python file functionality. + # Exits with error if jupytext is not installed. + # """ + if ! pip show jupytext > /dev/null 2>&1; then + echo "ERROR: jupytext is not installed but is required to run Jupyter Lab." + echo "Install it with: pip install jupytext" + exit 1 + fi +} + + +setup_jupyter_environment() { + # """ + # Configure Jupyter Lab environment before launching. + # + # Performs all necessary setup steps: + # - Configure vim keybindings + # - Disable notifications + # - Configure autosave interval + # - Verify jupytext is installed + # """ + configure_jupyter_vim_keybindings + configure_jupyter_notifications + configure_jupyter_autosave + check_jupytext_installed +} + + +get_jupyter_args() { + # """ + # Print the standard Jupyter Lab command-line arguments. + # + # :return: space-separated Jupyter Lab args for port 8888 with no browser, + # allow root, and no authentication + # """ + echo "--port=8888 --no-browser --ip=0.0.0.0 --allow-root --ServerApp.token='' --ServerApp.password=''" +} + + +get_run_jupyter_cmd() { + # """ + # Return the command to run run_jupyter.sh inside a container. + # + # Computes the script's path relative to GIT_ROOT and builds the + # corresponding /git_root/... path used inside the container. + # + # :param script_path: path of the calling script (pass ${BASH_SOURCE[0]}) + # :param cmd_opts: options to forward to run_jupyter.sh + # :return: full command string to run run_jupyter.sh + # """ + local script_path=$1 + local cmd_opts=$2 + local script_dir + script_dir=$(cd "$(dirname "$script_path")" && pwd) + local rel_dir="${script_dir#${GIT_ROOT}/}" + echo "/git_root/${rel_dir}/run_jupyter.sh $cmd_opts" +} + + +list_and_inspect_docker_image() { + # """ + # List available Docker images and inspect their architecture. + # + # Lists all images matching FULL_IMAGE_NAME and attempts to inspect + # their architecture using docker manifest inspect. + # """ + run "docker image ls $FULL_IMAGE_NAME" + (docker manifest inspect $FULL_IMAGE_NAME | grep arch) || true +} + + +kill_existing_container_if_forced() { + # """ + # Kill existing container if FORCE flag is set. + # + # If FORCE is set to 1, kills and removes the container with name + # CONTAINER_NAME. This is typically set by the -f flag. + # """ + if [[ $FORCE == 1 ]]; then + kill_container_by_name $CONTAINER_NAME + fi +}