Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ DB_NAME=cognee_db
#DB_USERNAME=cognee
#DB_PASSWORD=cognee

# -- Advanced: Custom database connection arguments (optional) ---------------
# Pass additional connection parameters as JSON. Useful for SSL, timeouts, etc.
# Examples:
# For PostgreSQL with SSL:
# DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}'
# For SQLite with custom timeout:
# DATABASE_CONNECT_ARGS='{"timeout": 60}'
#DATABASE_CONNECT_ARGS='{}'

################################################################################
# 🕸️ Graph Database settings
################################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from os import path
import tempfile
from uuid import UUID
import json
from typing import Optional
from typing import AsyncGenerator, List
from contextlib import asynccontextmanager
Expand All @@ -29,10 +30,25 @@ class SQLAlchemyAdapter:
functions.
"""

def __init__(self, connection_string: str):
def __init__(self, connection_string: str, connect_args: Optional[dict] = None):
self.db_path: str = None
self.db_uri: str = connection_string

env_connect_args = os.getenv("DATABASE_CONNECT_ARGS")
if env_connect_args:
try:
env_connect_args = json.loads(env_connect_args)
if isinstance(env_connect_args, dict):
if connect_args is None:
connect_args = {}
connect_args.update(env_connect_args)
else:
logger.warning(
f"DATABASE_CONNECT_ARGS is not a valid JSON dictionary: {env_connect_args}"
)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse DATABASE_CONNECT_ARGS as JSON: {e}")

if "sqlite" in connection_string:
[prefix, db_path] = connection_string.split("///")
self.db_path = db_path
Expand All @@ -53,7 +69,7 @@ def __init__(self, connection_string: str):
self.engine = create_async_engine(
connection_string,
poolclass=NullPool,
connect_args={"timeout": 30},
connect_args={**(connect_args or {}), **{"timeout": 30}},
)
else:
self.engine = create_async_engine(
Expand All @@ -63,6 +79,7 @@ def __init__(self, connection_string: str):
pool_recycle=280,
pool_pre_ping=True,
pool_timeout=280,
connect_args=connect_args or {},
)

self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
Expand Down