Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 75 additions & 20 deletions src/praisonai-agents/praisonaiagents/rag/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
"""

import hashlib
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

if TYPE_CHECKING:
from ..knowledge.models import SearchResultItem


ResultItem = Union[Dict[str, Any], "SearchResultItem"]


def _estimate_tokens(text: str) -> int:
Expand All @@ -19,6 +25,39 @@ def _estimate_tokens(text: str) -> int:
return len(text) // 4 + 1


def _extract_value(item: ResultItem, key: str, default: Any = None) -> Any:
"""
Extract value from either dict or object (e.g., SearchResultItem).

Args:
item: Dictionary or object with attributes
key: Key/attribute name to extract
default: Default value if key not found

Returns:
Extracted value or default
"""
if isinstance(item, dict):
return item.get(key, default)
else:
return getattr(item, key, default)


def _extract_metadata_value(
item: ResultItem,
metadata: Dict[str, Any],
key: str,
default: Any = "",
) -> Any:
"""
Extract metadata value with fallback to top-level item attribute.
"""
value = metadata.get(key)
if value is None:
return _extract_value(item, key, default)
return value


def _chunk_hash(text: str, source: Optional[str] = None) -> str:
"""
Generate stable hash for chunk deduplication.
Expand All @@ -35,31 +74,38 @@ def _chunk_hash(text: str, source: Optional[str] = None) -> str:


def deduplicate_chunks(
results: List[Dict[str, Any]],
results: List[ResultItem],
similarity_threshold: float = 0.9,
) -> List[Dict[str, Any]]:
) -> List[ResultItem]:
"""
Deduplicate chunks by content hash.

Handles both dict format and SearchResultItem objects.

Args:
results: List of search results with 'text' or 'memory' key
results: List of search results (dicts or SearchResultItem objects)
similarity_threshold: Not used currently (hash-based dedup)

Returns:
Deduplicated list of results
"""
seen_hashes: Set[str] = set()
unique_results: List[Dict[str, Any]] = []
unique_results: List[ResultItem] = []

for result in results:
# Skip None results
if result is None:
continue
# Handle different result formats
text = result.get("text") or result.get("memory", "")

# Handle different result formats (dict or SearchResultItem)
text = _extract_value(result, "text") or _extract_value(result, "memory", "")

# CRITICAL: Handle metadata=None from mem0 - ensure always dict
metadata = result.get("metadata") or {}
source = metadata.get("source", "")
metadata = _extract_value(result, "metadata") or {}
if not isinstance(metadata, dict):
metadata = {}

source = _extract_metadata_value(result, metadata, "source", "")

chunk_id = _chunk_hash(text, source)
Comment thread
greptile-apps[bot] marked this conversation as resolved.

Expand Down Expand Up @@ -106,17 +152,19 @@ def truncate_context(


def build_context(
results: List[Dict[str, Any]],
results: List[ResultItem],
max_tokens: int = 4000,
deduplicate: bool = True,
separator: str = "\n\n---\n\n",
include_source: bool = True,
) -> Tuple[str, List[Dict[str, Any]]]:
) -> Tuple[str, List[ResultItem]]:
"""
Build context string from retrieval results.

Handles both dict format and SearchResultItem objects.

Args:
results: List of search results
results: List of search results (dicts or SearchResultItem objects)
max_tokens: Maximum tokens for context
deduplicate: Whether to deduplicate chunks
separator: Separator between chunks
Expand All @@ -133,25 +181,29 @@ def build_context(
results = deduplicate_chunks(results)

context_parts: List[str] = []
used_results: List[Dict[str, Any]] = []
used_results: List[ResultItem] = []
current_tokens = 0
separator_tokens = _estimate_tokens(separator)

for i, result in enumerate(results):
# Skip None results
if result is None:
continue
# Handle different result formats
text = result.get("text") or result.get("memory", "")

# Handle different result formats (dict or SearchResultItem)
text = _extract_value(result, "text") or _extract_value(result, "memory", "")
if not text:
continue

# Build chunk text with optional source
if include_source:
# CRITICAL: Handle metadata=None from mem0 - ensure always dict
metadata = result.get("metadata") or {}
source = metadata.get("source", "")
filename = metadata.get("filename", "")
metadata = _extract_value(result, "metadata") or {}
if not isinstance(metadata, dict):
metadata = {}

source = _extract_metadata_value(result, metadata, "source", "")
filename = _extract_metadata_value(result, metadata, "filename", "")
source_label = filename or source or f"Source {i + 1}"
chunk_text = f"[{source_label}]\n{text}"
else:
Expand Down Expand Up @@ -194,11 +246,14 @@ def __init__(

def build(
self,
results: List[Dict[str, Any]],
results: List[ResultItem],
max_tokens: int = 4000,
deduplicate: bool = True,
) -> str:
"""Build context string from results."""
"""Build context string from results.

Handles both dict format and SearchResultItem objects.
"""
context, _ = build_context(
results=results,
max_tokens=max_tokens,
Expand Down
185 changes: 185 additions & 0 deletions src/praisonai-agents/tests/unit/rag/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,188 @@ def test_default_builder_custom_separator(self):

context = builder.build(results)
assert "\n\n" in context


class TestExtractValue:
"""Tests for the _extract_value helper function."""

def test_extract_value_from_dict(self):
"""Test extracting value from dictionary."""
from praisonaiagents.rag.context import _extract_value

item = {"text": "test content", "metadata": {"source": "test.pdf"}}

assert _extract_value(item, "text") == "test content"
assert _extract_value(item, "metadata")["source"] == "test.pdf"
assert _extract_value(item, "nonexistent", "default") == "default"

def test_extract_value_from_searchresultitem(self):
"""Test extracting value from SearchResultItem object."""
from praisonaiagents.rag.context import _extract_value
from praisonaiagents.knowledge.models import SearchResultItem

item = SearchResultItem(
text="test content",
source="test.pdf",
filename="test.pdf",
metadata={"extra": "data"}
)

assert _extract_value(item, "text") == "test content"
assert _extract_value(item, "source") == "test.pdf"
assert _extract_value(item, "filename") == "test.pdf"
assert _extract_value(item, "metadata")["extra"] == "data"
assert _extract_value(item, "nonexistent", "default") == "default"


class TestExtractMetadataValue:
"""Tests for the _extract_metadata_value helper function."""

def test_extract_metadata_value_from_metadata(self):
"""Test extracting value from metadata dict."""
from praisonaiagents.rag.context import _extract_metadata_value
from praisonaiagents.knowledge.models import SearchResultItem

item = SearchResultItem(text="content", source="fallback.pdf")
metadata = {"source": "metadata.pdf"}

# Should prefer metadata over top-level
result = _extract_metadata_value(item, metadata, "source", "")
assert result == "metadata.pdf"

def test_extract_metadata_value_fallback_to_toplevel(self):
"""Test fallback to top-level when metadata doesn't have key."""
from praisonaiagents.rag.context import _extract_metadata_value
from praisonaiagents.knowledge.models import SearchResultItem

item = SearchResultItem(text="content", source="toplevel.pdf")
metadata = {} # Empty metadata

# Should fall back to top-level source
result = _extract_metadata_value(item, metadata, "source", "")
assert result == "toplevel.pdf"

def test_extract_metadata_value_with_default(self):
"""Test fallback to default when neither metadata nor top-level have key."""
from praisonaiagents.rag.context import _extract_metadata_value
from praisonaiagents.knowledge.models import SearchResultItem

item = SearchResultItem(text="content") # No source
metadata = {} # No source in metadata

# Should use default
result = _extract_metadata_value(item, metadata, "source", "default.pdf")
assert result == "default.pdf"


class TestSearchResultItemSupport:
"""Tests for SearchResultItem object support in context functions."""

def test_deduplicate_chunks_with_searchresultitem(self):
"""Test deduplication with SearchResultItem objects."""
from praisonaiagents.rag.context import deduplicate_chunks
from praisonaiagents.knowledge.models import SearchResultItem

results = [
SearchResultItem(text="Same content", source="a.pdf"),
SearchResultItem(text="Same content", source="a.pdf"), # Duplicate
SearchResultItem(text="Different content", source="b.pdf"),
]

deduped = deduplicate_chunks(results)
assert len(deduped) == 2 # One duplicate should be removed

def test_deduplicate_chunks_mixed_formats(self):
"""Test deduplication with mixed dict and SearchResultItem formats."""
from praisonaiagents.rag.context import deduplicate_chunks
from praisonaiagents.knowledge.models import SearchResultItem

results = [
{"text": "Content A", "metadata": {"source": "a.pdf"}},
SearchResultItem(text="Content A", source="a.pdf"), # Should be deduped
SearchResultItem(text="Content B", source="b.pdf"),
]

deduped = deduplicate_chunks(results)
assert len(deduped) == 2 # One duplicate should be removed

def test_deduplicate_with_top_level_source(self):
"""Test deduplication considers top-level source from SearchResultItem."""
from praisonaiagents.rag.context import deduplicate_chunks
from praisonaiagents.knowledge.models import SearchResultItem

results = [
SearchResultItem(text="Same text", source="source1.pdf"),
SearchResultItem(text="Same text", source="source2.pdf"),
]

deduped = deduplicate_chunks(results)
assert len(deduped) == 2 # Different sources should not be deduped

def test_build_context_with_searchresultitem(self):
"""Test building context with SearchResultItem objects."""
from praisonaiagents.rag.context import build_context
from praisonaiagents.knowledge.models import SearchResultItem

results = [
SearchResultItem(
text="First content",
source="source1.pdf",
filename="file1.pdf"
),
SearchResultItem(
text="Second content",
source="source2.pdf",
filename="file2.pdf"
),
]

context, used = build_context(results)

assert "First content" in context
assert "Second content" in context
assert "file1.pdf" in context # Should use filename
assert "file2.pdf" in context
assert len(used) == 2

def test_build_context_source_fallback(self):
"""Test that build_context falls back to top-level source when metadata is empty."""
from praisonaiagents.rag.context import build_context
from praisonaiagents.knowledge.models import SearchResultItem

# SearchResultItem with source but empty metadata
results = [
SearchResultItem(
text="Content with source",
source="fallback.pdf",
metadata={} # Empty metadata, should fall back to top-level source
)
]

context, used = build_context(results, include_source=True)

assert "Content with source" in context
assert "fallback.pdf" in context # Should use top-level source
assert len(used) == 1

def test_build_context_mixed_formats_with_sources(self):
"""Test building context with mixed dict/SearchResultItem formats."""
from praisonaiagents.rag.context import build_context
from praisonaiagents.knowledge.models import SearchResultItem

results = [
{"text": "Dict content", "metadata": {"filename": "dict.pdf"}},
SearchResultItem(
text="Object content",
filename="object.pdf",
metadata={}
),
]

context, used = build_context(results, include_source=True)

assert "Dict content" in context
assert "Object content" in context
assert "dict.pdf" in context
assert "object.pdf" in context # Should use top-level filename
assert len(used) == 2
Loading
Loading