-
Notifications
You must be signed in to change notification settings - Fork 920
feat(database): add connect_args support to SqlAlchemyAdapter #1861
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
f9b16e5
c892265
3f53534
4f3a1bc
a7da9c7
1f98d50
f26b490
654a573
e1d313a
2de1bd9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| import os | ||
| import json | ||
| import pydantic | ||
| from typing import Union | ||
| from functools import lru_cache | ||
|
|
@@ -19,6 +20,7 @@ class RelationalConfig(BaseSettings): | |
| db_username: Union[str, None] = None # "cognee" | ||
| db_password: Union[str, None] = None # "cognee" | ||
| db_provider: str = "sqlite" | ||
| database_connect_args: Union[str, None] = None | ||
|
|
||
| model_config = SettingsConfigDict(env_file=".env", extra="allow") | ||
|
|
||
|
|
@@ -30,6 +32,17 @@ def fill_derived(self): | |
| databases_directory_path = os.path.join(base_config.system_root_directory, "databases") | ||
| self.db_path = databases_directory_path | ||
|
|
||
| # Parse database_connect_args if provided as JSON string | ||
| if self.database_connect_args and isinstance(self.database_connect_args, str): | ||
| try: | ||
| parsed_args = json.loads(self.database_connect_args) | ||
| if isinstance(parsed_args, dict): | ||
| self.database_connect_args = parsed_args | ||
| else: | ||
| self.database_connect_args = {} | ||
| except json.JSONDecodeError: | ||
| self.database_connect_args = {} | ||
|
Comment on lines
+36
to
+44
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Inconsistent handling of empty strings. The parsing logic skips empty strings (line 36 checks Apply this diff to handle empty strings consistently: # Parse database_connect_args if provided as JSON string
-if self.database_connect_args and isinstance(self.database_connect_args, str):
+if self.database_connect_args is not None and isinstance(self.database_connect_args, str):
+ if not self.database_connect_args.strip():
+ self.database_connect_args = None
+ else:
- try:
- parsed_args = json.loads(self.database_connect_args)
- if isinstance(parsed_args, dict):
- self.database_connect_args = parsed_args
- else:
- self.database_connect_args = {}
- except json.JSONDecodeError:
- self.database_connect_args = {}
+ try:
+ parsed_args = json.loads(self.database_connect_args)
+ if isinstance(parsed_args, dict):
+ self.database_connect_args = parsed_args
+ else:
+ self.database_connect_args = {}
+ except json.JSONDecodeError:
+ self.database_connect_args = {}🤖 Prompt for AI Agents |
||
|
|
||
| return self | ||
|
|
||
| def to_dict(self) -> dict: | ||
|
|
@@ -40,7 +53,8 @@ def to_dict(self) -> dict: | |
| -------- | ||
|
|
||
| - dict: A dictionary containing database configuration settings including db_path, | ||
| db_name, db_host, db_port, db_username, db_password, and db_provider. | ||
| db_name, db_host, db_port, db_username, db_password, db_provider, and | ||
| database_connect_args. | ||
| """ | ||
| return { | ||
| "db_path": self.db_path, | ||
|
|
@@ -50,6 +64,7 @@ def to_dict(self) -> dict: | |
| "db_username": self.db_username, | ||
| "db_password": self.db_password, | ||
| "db_provider": self.db_provider, | ||
| "database_connect_args": self.database_connect_args, | ||
| } | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ def create_relational_engine( | |
| db_username: str, | ||
| db_password: str, | ||
| db_provider: str, | ||
| database_connect_args: dict = None, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Use The type annotation Apply this diff: +from typing import Optional
+
@lru_cache
def create_relational_engine(
db_path: str,
db_name: str,
db_host: str,
db_port: str,
db_username: str,
db_password: str,
db_provider: str,
- database_connect_args: dict = None,
+ database_connect_args: Optional[dict] = None,
):🤖 Prompt for AI Agents |
||
| ): | ||
| """ | ||
| Create a relational database engine based on the specified parameters. | ||
|
|
@@ -29,6 +30,7 @@ def create_relational_engine( | |
| - db_password (str): The password for database authentication, required for | ||
| PostgreSQL. | ||
| - db_provider (str): The type of database provider (e.g., 'sqlite' or 'postgres'). | ||
| - database_connect_args (dict, optional): Database driver connection arguments. | ||
|
|
||
| Returns: | ||
| -------- | ||
|
|
@@ -51,4 +53,4 @@ def create_relational_engine( | |
| "PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality." | ||
| ) | ||
|
|
||
| return SQLAlchemyAdapter(connection_string) | ||
| return SQLAlchemyAdapter(connection_string, connect_args=database_connect_args) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,10 +29,31 @@ class SQLAlchemyAdapter: | |
| functions. | ||
| """ | ||
|
|
||
| def __init__(self, connection_string: str): | ||
| def __init__(self, connection_string: str, connect_args: dict = None): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Use Similar to Apply this diff: +from typing import Optional
+
- def __init__(self, connection_string: str, connect_args: dict = None):
+ def __init__(self, connection_string: str, connect_args: Optional[dict] = None):
🤖 Prompt for AI Agents |
||
| """ | ||
| Initialize the SQLAlchemy adapter with connection settings. | ||
|
|
||
| Parameters: | ||
| ----------- | ||
| connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db' | ||
| or 'postgresql://user:pass@host:port/db'). | ||
| connect_args (dict, optional): Database driver connection arguments. | ||
| Configuration is loaded from RelationalConfig.database_connect_args, which reads | ||
| from the DATABASE_CONNECT_ARGS environment variable. | ||
|
|
||
| Examples: | ||
| PostgreSQL with SSL: | ||
| DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}' | ||
|
|
||
| SQLite with custom timeout: | ||
| DATABASE_CONNECT_ARGS='{"timeout": 60}' | ||
| """ | ||
| self.db_path: str = None | ||
| self.db_uri: str = connection_string | ||
|
|
||
| # Use provided connect_args (already parsed from config) | ||
| final_connect_args = connect_args or {} | ||
|
|
||
| if "sqlite" in connection_string: | ||
| [prefix, db_path] = connection_string.split("///") | ||
| self.db_path = db_path | ||
|
|
@@ -53,7 +74,7 @@ def __init__(self, connection_string: str): | |
| self.engine = create_async_engine( | ||
| connection_string, | ||
| poolclass=NullPool, | ||
| connect_args={"timeout": 30}, | ||
| connect_args={**{"timeout": 30}, **final_connect_args}, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There should be no need to unpack two dictionaries here. We can simply put |
||
| ) | ||
| else: | ||
| self.engine = create_async_engine( | ||
|
|
@@ -63,6 +84,7 @@ def __init__(self, connection_string: str): | |
| pool_recycle=280, | ||
| pool_pre_ping=True, | ||
| pool_timeout=280, | ||
| connect_args=final_connect_args, | ||
| ) | ||
|
|
||
| self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| import os | ||
| from unittest.mock import patch | ||
| from cognee.infrastructure.databases.relational.config import RelationalConfig | ||
|
|
||
|
|
||
| class TestRelationalConfig: | ||
| """Test suite for RelationalConfig DATABASE_CONNECT_ARGS parsing.""" | ||
|
|
||
| def test_database_connect_args_valid_json_dict(self): | ||
| """Test that DATABASE_CONNECT_ARGS is parsed correctly when it's a valid JSON dict.""" | ||
| with patch.dict( | ||
| os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require"}'} | ||
| ): | ||
| config = RelationalConfig() | ||
| assert config.database_connect_args == {"timeout": 60, "sslmode": "require"} | ||
|
|
||
| def test_database_connect_args_empty_string(self): | ||
| """Test that empty DATABASE_CONNECT_ARGS is handled correctly.""" | ||
| with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": ""}): | ||
| config = RelationalConfig() | ||
| assert config.database_connect_args == "" | ||
|
|
||
| def test_database_connect_args_not_set(self): | ||
| """Test that missing DATABASE_CONNECT_ARGS results in None.""" | ||
| with patch.dict(os.environ, {}, clear=True): | ||
| config = RelationalConfig() | ||
| assert config.database_connect_args is None | ||
|
|
||
| def test_database_connect_args_invalid_json(self): | ||
| """Test that invalid JSON in DATABASE_CONNECT_ARGS results in empty dict.""" | ||
| with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON | ||
| config = RelationalConfig() | ||
| assert config.database_connect_args == {} | ||
|
|
||
| def test_database_connect_args_non_dict_json(self): | ||
| """Test that non-dict JSON in DATABASE_CONNECT_ARGS results in empty dict.""" | ||
| with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}): | ||
| config = RelationalConfig() | ||
| assert config.database_connect_args == {} | ||
|
|
||
| def test_database_connect_args_to_dict(self): | ||
| """Test that database_connect_args is included in to_dict() output.""" | ||
| with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}): | ||
| config = RelationalConfig() | ||
| config_dict = config.to_dict() | ||
| assert "database_connect_args" in config_dict | ||
| assert config_dict["database_connect_args"] == {"timeout": 60} | ||
|
|
||
| def test_database_connect_args_integer_value(self): | ||
| """Test that DATABASE_CONNECT_ARGS with integer values is parsed correctly.""" | ||
| with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"connect_timeout": 10}'}): | ||
| config = RelationalConfig() | ||
| assert config.database_connect_args == {"connect_timeout": 10} | ||
|
|
||
| def test_database_connect_args_mixed_types(self): | ||
| """Test that DATABASE_CONNECT_ARGS with mixed value types is parsed correctly.""" | ||
| with patch.dict( | ||
| os.environ, | ||
| { | ||
| "DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require", "retries": 3, "keepalive": true}' | ||
| }, | ||
| ): | ||
| config = RelationalConfig() | ||
| assert config.database_connect_args == { | ||
| "timeout": 60, | ||
| "sslmode": "require", | ||
| "retries": 3, | ||
| "keepalive": True, | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix type annotation to reflect runtime types.
The field
database_connect_argsis declared asUnion[str, None], but after thefill_derivedvalidator runs (lines 36-44), it can also be adict. This creates a type inconsistency.Apply this diff to fix the type annotation:
📝 Committable suggestion
🤖 Prompt for AI Agents