Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/cldk_workflow.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ on:
- 'recipes/CodeLLM_DevKit/code_summarization.ipynb'
- 'recipes/CodeLLM_DevKit/generate_unit_tests.ipynb'
- 'recipes/CodeLLM_DevKit/validating_code_translation.ipynb'
- 'recipes/Text_to_SQL/Text_to_SQL_1a.ipynb'
pull_request:
branches:
- main
paths:
- 'recipes/CodeLLM_DevKit/code_summarization.ipynb'
- 'recipes/CodeLLM_DevKit/generate_unit_tests.ipynb'
- 'recipes/CodeLLM_DevKit/validating_code_translation.ipynb'

- 'recipes/Text_to_SQL/Text_to_SQL_1a.ipynb'
jobs:
test-java-notebooks:
runs-on: ubuntu-latest
Expand Down
264 changes: 264 additions & 0 deletions recipes/Text_to_SQL/Text_to_SQL_1a.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "6c017e1f",
"metadata": {},
"source": [
"# Getting Started with Granite Code"
]
},
{
"cell_type": "markdown",
"id": "e7d00394",
"metadata": {},
"source": [
"# Introduction - Text to SQL"
]
},
{
"cell_type": "markdown",
"id": "52237aef",
"metadata": {},
"source": [
"In this notebook, you will learn how to convert natural language text to SQL through LLM. DB2 is used in this example and Langchain APIs are used to connect to database and execute queries.\n",
"\n",
"Note: This notebook was tested in Windows OS"
]
},
{
"cell_type": "markdown",
"id": "872eb65b",
"metadata": {},
"source": [
"## Pre-requisites\n",
"\n",
"1. Python 3.11 or later\n",
"2. DB2 driver (DB2 is used in this example)\n",
"3. Install `langchain` and `replicate` Python packages using the following command in your terminal or command prompt:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33fb3190-59fb-4edd-82dd-f20f6eab3e47",
"metadata": {},
"outputs": [],
"source": [
"!pip install langchain replicate"
]
},
{
"cell_type": "markdown",
"id": "c53aa7dc",
"metadata": {},
"source": [
"## Getting the variable from .env file \n",
"\n",
"This cell outlines how to load environment variables from a `.env` file and verify the availability of an API token. Create '.env' file with these variables\n",
"\n",
"REPLICATE_API_TOKEN= \"replicate token\"\n",
"\n",
"DATABASE_URI=\"database URI\"\n",
"\n",
"\n",
"Here are the various formats of different Databases URIs to connect \n",
"\n",
"DB2 - db2+ibm_db://username:password@hostname:port/database\n",
"\n",
"Oracle - oracle+cx_oracle://username:password@hostname:port/?service_name=service\n",
"\n",
"PostgreSQL - postgresql://username:password@hostname:port/database"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa4562d3",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from dotenv import load_dotenv\n",
"\n",
"# Load environment variables from the .env file\n",
"load_dotenv()\n",
"\n",
"# Retrieve the API token directly from environment variables\n",
"replicate_api_token = os.getenv('REPLICATE_API_TOKEN')\n",
"\n",
"if replicate_api_token:\n",
" print('API token loaded successfully:', replicate_api_token)\n",
"else:\n",
" print('Failed to load API token. Please check your .env file.')"
]
},
{
"cell_type": "markdown",
"id": "91e647d9",
"metadata": {},
"source": [
"## Initialize the Model \n",
"In this section, we specify the model ID used to invoke specific models from IBM Granite on the Replicate platform. Depending on the model's capacity and the task complexity, you can choose between models with different instruction capacities like `8b` or `20b`.\n",
"### Model Retrieval\n",
"\n",
"Here, we retrieve the model using the find_langchain_model function. This function is designed to locate and initialize models based on the platform (\"Replicate\") and the specified model_id."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9dcd744c",
"metadata": {},
"outputs": [],
"source": [
"from ibm_granite_community.langchain_utils import find_langchain_model\n",
"model_id = \"ibm-granite/granite-8b-code-instruct-128k\"\n",
"# model_id = \"ibm-granite/granite-20b-code-instruct-8k\"\n",
"\n",
"granite_via_replicate = find_langchain_model(platform=\"Replicate\", model_id=model_id)"
]
},
{
"cell_type": "markdown",
"id": "6ebef218",
"metadata": {},
"source": [
"## Connecting to Database and generating the SQL with a language model \n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "cd036c20",
"metadata": {},
"source": [
"NOTE : Please download drivers of database you're using and try to install and make note of the path as its been used in the script. Below cell will interact with a DB2 database and set the schema."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "126ebe75",
"metadata": {},
"outputs": [],
"source": [
"#import re\n",
"from langchain_community.utilities import SQLDatabase\n",
"import os\n",
"import ibm_db_dbi\n",
"\n",
"# Add the directory containing the DB2 DLLs to the search path\n",
"os.add_dll_directory('C:\\\\Program Files\\\\IBM\\\\SQLLIB\\\\BIN')\n",
"\n",
"# Initialize the SQLDatabase\n",
"db = SQLDatabase.from_uri(\"<DATABASE_URI>\")\n",
"\n",
"# The function get_schema() dynamically fetches the current schema to ensure the prompt is accurate and up-to-date\n",
"\n",
"def get_schema():\n",
" return db.get_table_info()\n",
"\n",
"try:\n",
" db.run(\"SET SCHEMA 'OMDB'\") # Replacing 'execute' with 'run'\n",
"except ibm_db_dbi.ProgrammingError as pe:\n",
" print(f\"SQL Programming Error: {pe}\")\n",
"except ibm_db_dbi.OperationalError as oe:\n",
" print(f\"Operational Error: {oe}\")\n",
"except Exception as e:\n",
" print(f\"General SQL Error: {e}\")\n",
"\n",
"# Fetch schema information\n",
"try:\n",
" schema_info = db.get_table_info() # Using get_table_info to fetch schema info\n",
" print(schema_info)\n",
"except AttributeError:\n",
" print(\"Method get_schema() not found in SQLDatabase class.\")\n",
"except Exception as e:\n",
" print(f\"Error fetching schema information: {e}\")\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "047ed309",
"metadata": {},
"source": [
"## SQL Query Generation and Execution Documentation\n",
"\n",
"A prompt is constructed to instruct the language model on what is expected. This prompt includes the current schema of the database and explicitly asks for an SQL query that answers the given question. \n",
"\n",
"Generates the SQL Query based on a language model, catering to specific database-related questions. It ensures that the queries are not only generated but also validated and executed securely."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c5b9da",
"metadata": {},
"outputs": [],
"source": [
"# Define a question relevant to the database for example, for an order management system, the questions could be 'Show all orders', 'what is the total count of orders'\n",
"question = \"Show all orders from yfs_order_header\"\n",
"\n",
"# Create a prompt for the model to generate an SQL query\n",
"prompt = f\"\"\"Based on the table schema below, write an SQL query that would answer the user's question; just return the SQL query and nothing else.\n",
"Schema:\n",
"{get_schema()}\n",
"\n",
"Question: {question}\n",
"\n",
"SQL Query:\"\"\"\n",
"\n",
"print(\"Prompt for the model:\", prompt)\n",
"\n",
"# Invoke the language model to generate the SQL query\n",
"try:\n",
" answer = granite_via_replicate.invoke(prompt)\n",
" # Use regular expression to find a common SQL pattern\n",
" match = re.search(r'\\b(SELECT|INSERT|UPDATE|DELETE)\\b.*?;', answer, re.IGNORECASE | re.DOTALL)\n",
" if match:\n",
" cleaned_answer = match.group(0).strip()\n",
" else:\n",
" cleaned_answer = \"No valid SQL query found.\"\n",
" print(\"Generated SQL Query:\", cleaned_answer)\n",
"except Exception as e:\n",
" print(\"Error invoking the model:\", e)\n",
"\n",
"# Validate and execute the SQL query if it is a valid string\n",
"if cleaned_answer and cleaned_answer != \"No valid SQL query found.\":\n",
" print(\"Final SQL Query to execute:\", cleaned_answer)\n",
" try:\n",
" result = db.run(cleaned_answer) # Execute the cleaned SQL command and fetch results\n",
" print(\"Query Results:\", result)\n",
" except ibm_db_dbi.ProgrammingError as pe:\n",
" print(\"SQL Programming Error:\", pe)\n",
" except Exception as e:\n",
" print(\"Error executing the SQL query:\", e)\n",
"else:\n",
" print(cleaned_answer)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}