Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ DB_NAME=cognee_db
#DB_USERNAME=cognee
#DB_PASSWORD=cognee

# -- Advanced: Custom database connection arguments (optional) ---------------
# Pass additional connection parameters as JSON. Useful for SSL, timeouts, etc.
# Examples:
# For PostgreSQL with SSL:
# DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}'
# For SQLite with custom timeout:
# DATABASE_CONNECT_ARGS='{"timeout": 60}'
#DATABASE_CONNECT_ARGS='{}'

################################################################################
# 🕸️ Graph Database settings
################################################################################
Expand Down
17 changes: 16 additions & 1 deletion cognee/infrastructure/databases/relational/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import json
import pydantic
from typing import Union
from functools import lru_cache
Expand All @@ -19,6 +20,7 @@ class RelationalConfig(BaseSettings):
db_username: Union[str, None] = None # "cognee"
db_password: Union[str, None] = None # "cognee"
db_provider: str = "sqlite"
database_connect_args: Union[str, None] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fix type annotation to reflect runtime types.

The field database_connect_args is declared as Union[str, None], but after the fill_derived validator runs (lines 36-44), it can also be a dict. This creates a type inconsistency.

Apply this diff to fix the type annotation:

-    database_connect_args: Union[str, None] = None
+    database_connect_args: Union[dict, str, None] = None
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
database_connect_args: Union[str, None] = None
database_connect_args: Union[dict, str, None] = None
🤖 Prompt for AI Agents
In cognee/infrastructure/databases/relational/config.py around line 23 (and
refer to the fill_derived validator at lines 36-44), the field
database_connect_args is annotated as Union[str, None] but at runtime the
validator can set it to a dict; update the type annotation to reflect
Optional[Union[str, Dict[str, Any]]] (or Optional[Union[str, dict]]) and add the
necessary typing imports (Dict and Any or use dict) so static types match
runtime values; ensure any references or mypy checks are adjusted accordingly.


model_config = SettingsConfigDict(env_file=".env", extra="allow")

Expand All @@ -30,6 +32,17 @@ def fill_derived(self):
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
self.db_path = databases_directory_path

# Parse database_connect_args if provided as JSON string
if self.database_connect_args and isinstance(self.database_connect_args, str):
try:
parsed_args = json.loads(self.database_connect_args)
if isinstance(parsed_args, dict):
self.database_connect_args = parsed_args
else:
self.database_connect_args = {}
except json.JSONDecodeError:
self.database_connect_args = {}
Comment on lines +36 to +44
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Inconsistent handling of empty strings.

The parsing logic skips empty strings (line 36 checks if self.database_connect_args), leaving them as empty strings rather than converting them to None or {}. This is inconsistent with the intent to provide a dict for connection args and could cause issues downstream.

Apply this diff to handle empty strings consistently:

 # Parse database_connect_args if provided as JSON string
-if self.database_connect_args and isinstance(self.database_connect_args, str):
+if self.database_connect_args is not None and isinstance(self.database_connect_args, str):
+    if not self.database_connect_args.strip():
+        self.database_connect_args = None
+    else:
-    try:
-        parsed_args = json.loads(self.database_connect_args)
-        if isinstance(parsed_args, dict):
-            self.database_connect_args = parsed_args
-        else:
-            self.database_connect_args = {}
-    except json.JSONDecodeError:
-        self.database_connect_args = {}
+        try:
+            parsed_args = json.loads(self.database_connect_args)
+            if isinstance(parsed_args, dict):
+                self.database_connect_args = parsed_args
+            else:
+                self.database_connect_args = {}
+        except json.JSONDecodeError:
+            self.database_connect_args = {}
🤖 Prompt for AI Agents
In cognee/infrastructure/databases/relational/config.py around lines 36 to 44,
the code skips empty strings because it checks `if self.database_connect_args`
before parsing, leaving "" unchanged; update the logic to first check if
`isinstance(self.database_connect_args, str)`, then strip the string and if it
is empty set `self.database_connect_args = {}`; otherwise attempt `json.loads`
and on success assign the dict (or `{}` if parsed value is not a dict), catching
`json.JSONDecodeError` and setting `{}` on error.


return self

def to_dict(self) -> dict:
Expand All @@ -40,7 +53,8 @@ def to_dict(self) -> dict:
--------

- dict: A dictionary containing database configuration settings including db_path,
db_name, db_host, db_port, db_username, db_password, and db_provider.
db_name, db_host, db_port, db_username, db_password, db_provider, and
database_connect_args.
"""
return {
"db_path": self.db_path,
Expand All @@ -50,6 +64,7 @@ def to_dict(self) -> dict:
"db_username": self.db_username,
"db_password": self.db_password,
"db_provider": self.db_provider,
"database_connect_args": self.database_connect_args,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def create_relational_engine(
db_username: str,
db_password: str,
db_provider: str,
database_connect_args: dict = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Use Optional[dict] for clarity and correctness.

The type annotation dict = None is technically valid but less explicit than Optional[dict] = None or dict | None = None. Per PEP 484 conventions, using Optional or the union syntax is preferred for optional parameters.

Apply this diff:

+from typing import Optional
+
 @lru_cache
 def create_relational_engine(
     db_path: str,
     db_name: str,
     db_host: str,
     db_port: str,
     db_username: str,
     db_password: str,
     db_provider: str,
-    database_connect_args: dict = None,
+    database_connect_args: Optional[dict] = None,
 ):
🤖 Prompt for AI Agents
In cognee/infrastructure/databases/relational/create_relational_engine.py around
line 14, the parameter annotation currently uses "dict = None"; change it to an
explicit optional type such as "Optional[dict] = None" (or "dict | None = None"
for Python 3.10+) and add the corresponding import from typing (from typing
import Optional) if not already present, to make the optional nature explicit
and follow PEP 484 conventions.

):
"""
Create a relational database engine based on the specified parameters.
Expand All @@ -29,6 +30,7 @@ def create_relational_engine(
- db_password (str): The password for database authentication, required for
PostgreSQL.
- db_provider (str): The type of database provider (e.g., 'sqlite' or 'postgres').
- database_connect_args (dict, optional): Database driver connection arguments.

Returns:
--------
Expand All @@ -51,4 +53,4 @@ def create_relational_engine(
"PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality."
)

return SQLAlchemyAdapter(connection_string)
return SQLAlchemyAdapter(connection_string, connect_args=database_connect_args)
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,31 @@ class SQLAlchemyAdapter:
functions.
"""

def __init__(self, connection_string: str):
def __init__(self, connection_string: str, connect_args: dict = None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Use Optional[dict] for type annotation consistency.

Similar to create_relational_engine.py, the type annotation dict = None should use Optional[dict] = None for clarity and PEP 484 compliance.

Apply this diff:

+from typing import Optional
+
-    def __init__(self, connection_string: str, connect_args: dict = None):
+    def __init__(self, connection_string: str, connect_args: Optional[dict] = None):

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py
around line 32, the __init__ parameter connect_args is annotated as dict = None;
change the annotation to Optional[dict] = None for PEP 484 consistency and match
create_relational_engine.py. Add the necessary import from typing (Optional) at
top of the file if missing, then update the signature to use Optional[dict] =
None.

"""
Initialize the SQLAlchemy adapter with connection settings.

Parameters:
-----------
connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db'
or 'postgresql://user:pass@host:port/db').
connect_args (dict, optional): Database driver connection arguments.
Configuration is loaded from RelationalConfig.database_connect_args, which reads
from the DATABASE_CONNECT_ARGS environment variable.

Examples:
PostgreSQL with SSL:
DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}'

SQLite with custom timeout:
DATABASE_CONNECT_ARGS='{"timeout": 60}'
"""
self.db_path: str = None
self.db_uri: str = connection_string

# Use provided connect_args (already parsed from config)
final_connect_args = connect_args or {}

if "sqlite" in connection_string:
[prefix, db_path] = connection_string.split("///")
self.db_path = db_path
Expand All @@ -53,7 +74,7 @@ def __init__(self, connection_string: str):
self.engine = create_async_engine(
connection_string,
poolclass=NullPool,
connect_args={"timeout": 30},
connect_args={**{"timeout": 30}, **final_connect_args},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be no need to unpack two dictionaries here. We can simply put
connect_args={"timeout": 30, **final_connect_args}.
Also worth noting that, with this order in the dictionary, if there is another timeout key in the final_connect_args, the first one will be overwritten. Here we probably want that, but just a note.

)
else:
self.engine = create_async_engine(
Expand All @@ -63,6 +84,7 @@ def __init__(self, connection_string: str):
pool_recycle=280,
pool_pre_ping=True,
pool_timeout=280,
connect_args=final_connect_args,
)

self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
from unittest.mock import patch
from cognee.infrastructure.databases.relational.config import RelationalConfig


class TestRelationalConfig:
"""Test suite for RelationalConfig DATABASE_CONNECT_ARGS parsing."""

def test_database_connect_args_valid_json_dict(self):
"""Test that DATABASE_CONNECT_ARGS is parsed correctly when it's a valid JSON dict."""
with patch.dict(
os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require"}'}
):
config = RelationalConfig()
assert config.database_connect_args == {"timeout": 60, "sslmode": "require"}

def test_database_connect_args_empty_string(self):
"""Test that empty DATABASE_CONNECT_ARGS is handled correctly."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": ""}):
config = RelationalConfig()
assert config.database_connect_args == ""

def test_database_connect_args_not_set(self):
"""Test that missing DATABASE_CONNECT_ARGS results in None."""
with patch.dict(os.environ, {}, clear=True):
config = RelationalConfig()
assert config.database_connect_args is None

def test_database_connect_args_invalid_json(self):
"""Test that invalid JSON in DATABASE_CONNECT_ARGS results in empty dict."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON
config = RelationalConfig()
assert config.database_connect_args == {}

def test_database_connect_args_non_dict_json(self):
"""Test that non-dict JSON in DATABASE_CONNECT_ARGS results in empty dict."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}):
config = RelationalConfig()
assert config.database_connect_args == {}

def test_database_connect_args_to_dict(self):
"""Test that database_connect_args is included in to_dict() output."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}):
config = RelationalConfig()
config_dict = config.to_dict()
assert "database_connect_args" in config_dict
assert config_dict["database_connect_args"] == {"timeout": 60}

def test_database_connect_args_integer_value(self):
"""Test that DATABASE_CONNECT_ARGS with integer values is parsed correctly."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"connect_timeout": 10}'}):
config = RelationalConfig()
assert config.database_connect_args == {"connect_timeout": 10}

def test_database_connect_args_mixed_types(self):
"""Test that DATABASE_CONNECT_ARGS with mixed value types is parsed correctly."""
with patch.dict(
os.environ,
{
"DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require", "retries": 3, "keepalive": true}'
},
):
config = RelationalConfig()
assert config.database_connect_args == {
"timeout": 60,
"sslmode": "require",
"retries": 3,
"keepalive": True,
}