diff --git a/src/praisonai-agents/praisonaiagents/rag/context.py b/src/praisonai-agents/praisonaiagents/rag/context.py index 11d7bef21..b6820dbf5 100644 --- a/src/praisonai-agents/praisonaiagents/rag/context.py +++ b/src/praisonai-agents/praisonaiagents/rag/context.py @@ -6,7 +6,13 @@ """ import hashlib -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union + +if TYPE_CHECKING: + from ..knowledge.models import SearchResultItem + + +ResultItem = Union[Dict[str, Any], "SearchResultItem"] def _estimate_tokens(text: str) -> int: @@ -19,6 +25,39 @@ def _estimate_tokens(text: str) -> int: return len(text) // 4 + 1 +def _extract_value(item: ResultItem, key: str, default: Any = None) -> Any: + """ + Extract value from either dict or object (e.g., SearchResultItem). + + Args: + item: Dictionary or object with attributes + key: Key/attribute name to extract + default: Default value if key not found + + Returns: + Extracted value or default + """ + if isinstance(item, dict): + return item.get(key, default) + else: + return getattr(item, key, default) + + +def _extract_metadata_value( + item: ResultItem, + metadata: Dict[str, Any], + key: str, + default: Any = "", +) -> Any: + """ + Extract metadata value with fallback to top-level item attribute. + """ + value = metadata.get(key) + if value is None: + return _extract_value(item, key, default) + return value + + def _chunk_hash(text: str, source: Optional[str] = None) -> str: """ Generate stable hash for chunk deduplication. @@ -35,31 +74,38 @@ def _chunk_hash(text: str, source: Optional[str] = None) -> str: def deduplicate_chunks( - results: List[Dict[str, Any]], + results: List[ResultItem], similarity_threshold: float = 0.9, -) -> List[Dict[str, Any]]: +) -> List[ResultItem]: """ Deduplicate chunks by content hash. + Handles both dict format and SearchResultItem objects. + Args: - results: List of search results with 'text' or 'memory' key + results: List of search results (dicts or SearchResultItem objects) similarity_threshold: Not used currently (hash-based dedup) Returns: Deduplicated list of results """ seen_hashes: Set[str] = set() - unique_results: List[Dict[str, Any]] = [] + unique_results: List[ResultItem] = [] for result in results: # Skip None results if result is None: continue - # Handle different result formats - text = result.get("text") or result.get("memory", "") + + # Handle different result formats (dict or SearchResultItem) + text = _extract_value(result, "text") or _extract_value(result, "memory", "") + # CRITICAL: Handle metadata=None from mem0 - ensure always dict - metadata = result.get("metadata") or {} - source = metadata.get("source", "") + metadata = _extract_value(result, "metadata") or {} + if not isinstance(metadata, dict): + metadata = {} + + source = _extract_metadata_value(result, metadata, "source", "") chunk_id = _chunk_hash(text, source) @@ -106,17 +152,19 @@ def truncate_context( def build_context( - results: List[Dict[str, Any]], + results: List[ResultItem], max_tokens: int = 4000, deduplicate: bool = True, separator: str = "\n\n---\n\n", include_source: bool = True, -) -> Tuple[str, List[Dict[str, Any]]]: +) -> Tuple[str, List[ResultItem]]: """ Build context string from retrieval results. + Handles both dict format and SearchResultItem objects. + Args: - results: List of search results + results: List of search results (dicts or SearchResultItem objects) max_tokens: Maximum tokens for context deduplicate: Whether to deduplicate chunks separator: Separator between chunks @@ -133,7 +181,7 @@ def build_context( results = deduplicate_chunks(results) context_parts: List[str] = [] - used_results: List[Dict[str, Any]] = [] + used_results: List[ResultItem] = [] current_tokens = 0 separator_tokens = _estimate_tokens(separator) @@ -141,17 +189,21 @@ def build_context( # Skip None results if result is None: continue - # Handle different result formats - text = result.get("text") or result.get("memory", "") + + # Handle different result formats (dict or SearchResultItem) + text = _extract_value(result, "text") or _extract_value(result, "memory", "") if not text: continue # Build chunk text with optional source if include_source: # CRITICAL: Handle metadata=None from mem0 - ensure always dict - metadata = result.get("metadata") or {} - source = metadata.get("source", "") - filename = metadata.get("filename", "") + metadata = _extract_value(result, "metadata") or {} + if not isinstance(metadata, dict): + metadata = {} + + source = _extract_metadata_value(result, metadata, "source", "") + filename = _extract_metadata_value(result, metadata, "filename", "") source_label = filename or source or f"Source {i + 1}" chunk_text = f"[{source_label}]\n{text}" else: @@ -194,11 +246,14 @@ def __init__( def build( self, - results: List[Dict[str, Any]], + results: List[ResultItem], max_tokens: int = 4000, deduplicate: bool = True, ) -> str: - """Build context string from results.""" + """Build context string from results. + + Handles both dict format and SearchResultItem objects. + """ context, _ = build_context( results=results, max_tokens=max_tokens, diff --git a/src/praisonai-agents/tests/unit/rag/test_context.py b/src/praisonai-agents/tests/unit/rag/test_context.py index e84649f1e..a95313ad3 100644 --- a/src/praisonai-agents/tests/unit/rag/test_context.py +++ b/src/praisonai-agents/tests/unit/rag/test_context.py @@ -223,3 +223,188 @@ def test_default_builder_custom_separator(self): context = builder.build(results) assert "\n\n" in context + + +class TestExtractValue: + """Tests for the _extract_value helper function.""" + + def test_extract_value_from_dict(self): + """Test extracting value from dictionary.""" + from praisonaiagents.rag.context import _extract_value + + item = {"text": "test content", "metadata": {"source": "test.pdf"}} + + assert _extract_value(item, "text") == "test content" + assert _extract_value(item, "metadata")["source"] == "test.pdf" + assert _extract_value(item, "nonexistent", "default") == "default" + + def test_extract_value_from_searchresultitem(self): + """Test extracting value from SearchResultItem object.""" + from praisonaiagents.rag.context import _extract_value + from praisonaiagents.knowledge.models import SearchResultItem + + item = SearchResultItem( + text="test content", + source="test.pdf", + filename="test.pdf", + metadata={"extra": "data"} + ) + + assert _extract_value(item, "text") == "test content" + assert _extract_value(item, "source") == "test.pdf" + assert _extract_value(item, "filename") == "test.pdf" + assert _extract_value(item, "metadata")["extra"] == "data" + assert _extract_value(item, "nonexistent", "default") == "default" + + +class TestExtractMetadataValue: + """Tests for the _extract_metadata_value helper function.""" + + def test_extract_metadata_value_from_metadata(self): + """Test extracting value from metadata dict.""" + from praisonaiagents.rag.context import _extract_metadata_value + from praisonaiagents.knowledge.models import SearchResultItem + + item = SearchResultItem(text="content", source="fallback.pdf") + metadata = {"source": "metadata.pdf"} + + # Should prefer metadata over top-level + result = _extract_metadata_value(item, metadata, "source", "") + assert result == "metadata.pdf" + + def test_extract_metadata_value_fallback_to_toplevel(self): + """Test fallback to top-level when metadata doesn't have key.""" + from praisonaiagents.rag.context import _extract_metadata_value + from praisonaiagents.knowledge.models import SearchResultItem + + item = SearchResultItem(text="content", source="toplevel.pdf") + metadata = {} # Empty metadata + + # Should fall back to top-level source + result = _extract_metadata_value(item, metadata, "source", "") + assert result == "toplevel.pdf" + + def test_extract_metadata_value_with_default(self): + """Test fallback to default when neither metadata nor top-level have key.""" + from praisonaiagents.rag.context import _extract_metadata_value + from praisonaiagents.knowledge.models import SearchResultItem + + item = SearchResultItem(text="content") # No source + metadata = {} # No source in metadata + + # Should use default + result = _extract_metadata_value(item, metadata, "source", "default.pdf") + assert result == "default.pdf" + + +class TestSearchResultItemSupport: + """Tests for SearchResultItem object support in context functions.""" + + def test_deduplicate_chunks_with_searchresultitem(self): + """Test deduplication with SearchResultItem objects.""" + from praisonaiagents.rag.context import deduplicate_chunks + from praisonaiagents.knowledge.models import SearchResultItem + + results = [ + SearchResultItem(text="Same content", source="a.pdf"), + SearchResultItem(text="Same content", source="a.pdf"), # Duplicate + SearchResultItem(text="Different content", source="b.pdf"), + ] + + deduped = deduplicate_chunks(results) + assert len(deduped) == 2 # One duplicate should be removed + + def test_deduplicate_chunks_mixed_formats(self): + """Test deduplication with mixed dict and SearchResultItem formats.""" + from praisonaiagents.rag.context import deduplicate_chunks + from praisonaiagents.knowledge.models import SearchResultItem + + results = [ + {"text": "Content A", "metadata": {"source": "a.pdf"}}, + SearchResultItem(text="Content A", source="a.pdf"), # Should be deduped + SearchResultItem(text="Content B", source="b.pdf"), + ] + + deduped = deduplicate_chunks(results) + assert len(deduped) == 2 # One duplicate should be removed + + def test_deduplicate_with_top_level_source(self): + """Test deduplication considers top-level source from SearchResultItem.""" + from praisonaiagents.rag.context import deduplicate_chunks + from praisonaiagents.knowledge.models import SearchResultItem + + results = [ + SearchResultItem(text="Same text", source="source1.pdf"), + SearchResultItem(text="Same text", source="source2.pdf"), + ] + + deduped = deduplicate_chunks(results) + assert len(deduped) == 2 # Different sources should not be deduped + + def test_build_context_with_searchresultitem(self): + """Test building context with SearchResultItem objects.""" + from praisonaiagents.rag.context import build_context + from praisonaiagents.knowledge.models import SearchResultItem + + results = [ + SearchResultItem( + text="First content", + source="source1.pdf", + filename="file1.pdf" + ), + SearchResultItem( + text="Second content", + source="source2.pdf", + filename="file2.pdf" + ), + ] + + context, used = build_context(results) + + assert "First content" in context + assert "Second content" in context + assert "file1.pdf" in context # Should use filename + assert "file2.pdf" in context + assert len(used) == 2 + + def test_build_context_source_fallback(self): + """Test that build_context falls back to top-level source when metadata is empty.""" + from praisonaiagents.rag.context import build_context + from praisonaiagents.knowledge.models import SearchResultItem + + # SearchResultItem with source but empty metadata + results = [ + SearchResultItem( + text="Content with source", + source="fallback.pdf", + metadata={} # Empty metadata, should fall back to top-level source + ) + ] + + context, used = build_context(results, include_source=True) + + assert "Content with source" in context + assert "fallback.pdf" in context # Should use top-level source + assert len(used) == 1 + + def test_build_context_mixed_formats_with_sources(self): + """Test building context with mixed dict/SearchResultItem formats.""" + from praisonaiagents.rag.context import build_context + from praisonaiagents.knowledge.models import SearchResultItem + + results = [ + {"text": "Dict content", "metadata": {"filename": "dict.pdf"}}, + SearchResultItem( + text="Object content", + filename="object.pdf", + metadata={} + ), + ] + + context, used = build_context(results, include_source=True) + + assert "Dict content" in context + assert "Object content" in context + assert "dict.pdf" in context + assert "object.pdf" in context # Should use top-level filename + assert len(used) == 2 diff --git a/src/praisonai-agents/tests/unit/rag/test_context_normalization.py b/src/praisonai-agents/tests/unit/rag/test_context_normalization.py index cb13a1bc6..68bf24688 100644 --- a/src/praisonai-agents/tests/unit/rag/test_context_normalization.py +++ b/src/praisonai-agents/tests/unit/rag/test_context_normalization.py @@ -8,6 +8,7 @@ deduplicate_chunks, build_context, ) +from praisonaiagents.knowledge.models import SearchResultItem class TestDeduplicateChunksNormalization: @@ -122,3 +123,25 @@ def test_empty_text_filtered(self): context, used = build_context(results) assert len(used) == 1 assert "valid content" in context + + +class TestSearchResultItemCompatibility: + """Tests for SearchResultItem object compatibility.""" + + def test_deduplicate_uses_source_attribute_fallback(self): + """Different source attributes should produce unique chunks.""" + results = [ + SearchResultItem(text="same content", source="a.pdf"), + SearchResultItem(text="same content", source="b.pdf"), + ] + deduped = deduplicate_chunks(results) + assert len(deduped) == 2 + + def test_build_context_uses_filename_attribute_fallback(self): + """Source label should use object filename/source attributes.""" + results = [ + SearchResultItem(text="chunk", filename="doc.md"), + ] + context, used = build_context(results, include_source=True) + assert len(used) == 1 + assert "doc.md" in context