Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 120 additions & 35 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@

logger = logging.getLogger(__name__)

# ref: https://www.sqlite.org/limits.html
# Lowest default maximum is 999, intentionally setting it to half
_SQLITE_MAX_BIND_VARS = 500

Model = TypeVar("Model")

Expand Down Expand Up @@ -361,10 +364,9 @@ def get_scores(
Returns:
Sequence[Score]: A list of Score objects that match the specified filters.
"""
# Build base conditions without score_ids, we will handle that with batching
conditions: list[Any] = []

if score_ids:
conditions.append(ScoreEntry.id.in_(score_ids))
if score_type:
conditions.append(ScoreEntry.score_type == score_type)
if score_category:
Expand All @@ -374,6 +376,18 @@ def get_scores(
if sent_before:
conditions.append(ScoreEntry.timestamp <= sent_before)

# Handle score_ids with batching to avoid SQLite bind variable limits
if score_ids:
all_entries: list[ScoreEntry] = []
for i in range(0, len(score_ids), _SQLITE_MAX_BIND_VARS):
batch = score_ids[i : i + _SQLITE_MAX_BIND_VARS]
batch_conditions = conditions + [ScoreEntry.id.in_(batch)]
batch_entries: Sequence[ScoreEntry] = self._query_entries(
ScoreEntry, conditions=and_(*batch_conditions)
)
all_entries.extend(batch_entries)
return [entry.get_score() for entry in all_entries]

if not conditions:
return []

Expand Down Expand Up @@ -532,16 +546,14 @@ def get_message_pieces(
Exception: If there is an error retrieving the prompts,
an exception is logged and an empty list is returned.
"""
# Build base conditions (without parameters that may need batching)
conditions = []
if attack_id:
conditions.append(self._get_message_pieces_attack_conditions(attack_id=str(attack_id)))
if role:
conditions.append(PromptMemoryEntry.role == role)
if conversation_id:
conditions.append(PromptMemoryEntry.conversation_id == str(conversation_id))
if prompt_ids:
prompt_ids = [str(pi) for pi in prompt_ids]
conditions.append(PromptMemoryEntry.id.in_(prompt_ids))
if labels:
conditions.extend(self._get_message_pieces_memory_label_conditions(memory_labels=labels))
if prompt_metadata:
Expand All @@ -550,21 +562,59 @@ def get_message_pieces(
conditions.append(PromptMemoryEntry.timestamp >= sent_after)
if sent_before:
conditions.append(PromptMemoryEntry.timestamp <= sent_before)
if original_values:
conditions.append(PromptMemoryEntry.original_value.in_(original_values))
if converted_values:
conditions.append(PromptMemoryEntry.converted_value.in_(converted_values))
if data_type:
conditions.append(PromptMemoryEntry.converted_value_data_type == data_type)
if not_data_type:
conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type)
if converted_value_sha256:

# Identify which parameter needs batching (prioritize the one provided)
batch_param = None
batch_values = None
batch_column = None

if prompt_ids:
batch_param = "prompt_ids"
batch_values = [str(pi) for pi in prompt_ids]
batch_column = PromptMemoryEntry.id
elif original_values and len(original_values) > _SQLITE_MAX_BIND_VARS:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the original code, it's possible to filter by all of these, not just one of them. Your changes assume it's at most one (see "elif").

batch_param = "original_values"
batch_values = list(original_values)
batch_column = PromptMemoryEntry.original_value
elif converted_values and len(converted_values) > _SQLITE_MAX_BIND_VARS:
batch_param = "converted_values"
batch_values = list(converted_values)
batch_column = PromptMemoryEntry.converted_value
elif converted_value_sha256 and len(converted_value_sha256) > _SQLITE_MAX_BIND_VARS:
batch_param = "converted_value_sha256"
batch_values = list(converted_value_sha256)
batch_column = PromptMemoryEntry.converted_value_sha256

# Add non-batched IN conditions
if original_values and batch_param != "original_values":
conditions.append(PromptMemoryEntry.original_value.in_(original_values))
if converted_values and batch_param != "converted_values":
conditions.append(PromptMemoryEntry.converted_value.in_(converted_values))
if converted_value_sha256 and batch_param != "converted_value_sha256":
conditions.append(PromptMemoryEntry.converted_value_sha256.in_(converted_value_sha256))

try:
memory_entries: Sequence[PromptMemoryEntry] = self._query_entries(
PromptMemoryEntry, conditions=and_(*conditions) if conditions else None, join_scores=True
)
if batch_values:
all_entries: MutableSequence[PromptMemoryEntry] = []
for i in range(0, len(batch_values), _SQLITE_MAX_BIND_VARS):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code seems to repeat several times. Is there a simple way to make this into a generic method that we can reuse?

batch = batch_values[i : i + _SQLITE_MAX_BIND_VARS]
batch_conditions = conditions + [batch_column.in_(batch)]
batch_entries: Sequence[PromptMemoryEntry] = self._query_entries(
PromptMemoryEntry,
conditions=and_(*batch_conditions) if batch_conditions else None,
join_scores=True,
)
all_entries.extend(batch_entries)
memory_entries = all_entries
else:
memory_entries = self._query_entries(
PromptMemoryEntry, conditions=and_(*conditions) if conditions else None, join_scores=True
)

message_pieces = [memory_entry.get_message_piece() for memory_entry in memory_entries]
return sort_message_pieces(message_pieces=message_pieces)
except Exception as e:
Expand Down Expand Up @@ -1238,34 +1288,62 @@ def get_attack_results(
Returns:
Sequence[AttackResult]: A list of AttackResult objects that match the specified filters.
"""
# Build base conditions (without parameters that may need batching)
conditions: list[ColumnElement[bool]] = []

if attack_result_ids is not None:
if len(attack_result_ids) == 0:
# Empty list means no results
return []
conditions.append(AttackResultEntry.id.in_(attack_result_ids))
if conversation_id:
conditions.append(AttackResultEntry.conversation_id == conversation_id)
if objective:
conditions.append(AttackResultEntry.objective.contains(objective))

if objective_sha256:
conditions.append(AttackResultEntry.objective_sha256.in_(objective_sha256))
if outcome:
conditions.append(AttackResultEntry.outcome == outcome)

if targeted_harm_categories:
# Use database-specific JSON query method
conditions.append(
self._get_attack_result_harm_category_condition(targeted_harm_categories=targeted_harm_categories)
)

if labels:
# Use database-specific JSON query method
conditions.append(self._get_attack_result_label_condition(labels=labels))

# Handle empty lists
if attack_result_ids is not None and len(attack_result_ids) == 0:
return []
if objective_sha256 is not None and len(objective_sha256) == 0:
return []

# Identify which parameter needs batching
batch_values = None
batch_column = None
batch_param_name = None

if attack_result_ids and len(attack_result_ids) > _SQLITE_MAX_BIND_VARS:
batch_values = list(attack_result_ids)
batch_column = AttackResultEntry.id
batch_param_name = "attack_result_ids"
elif objective_sha256 and len(objective_sha256) > _SQLITE_MAX_BIND_VARS:
batch_values = list(objective_sha256)
batch_column = AttackResultEntry.objective_sha256
batch_param_name = "objective_sha256"

# Add non-batched IN conditions
if attack_result_ids and batch_param_name != "attack_result_ids":
conditions.append(AttackResultEntry.id.in_(attack_result_ids))
if objective_sha256 and batch_param_name != "objective_sha256":
conditions.append(AttackResultEntry.objective_sha256.in_(objective_sha256))

try:
if batch_values:
all_entries: list[AttackResultEntry] = []
for i in range(0, len(batch_values), _SQLITE_MAX_BIND_VARS):
batch = batch_values[i : i + _SQLITE_MAX_BIND_VARS]
batch_conditions = list(conditions) + [batch_column.in_(batch)]
batch_entries: Sequence[AttackResultEntry] = self._query_entries(
AttackResultEntry, conditions=and_(*batch_conditions) if batch_conditions else None
)
all_entries.extend(batch_entries)
return [entry.get_attack_result() for entry in all_entries]

entries: Sequence[AttackResultEntry] = self._query_entries(
AttackResultEntry, conditions=and_(*conditions) if conditions else None
)
Expand Down Expand Up @@ -1426,18 +1504,13 @@ def get_scenario_results(
Returns:
Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters.
"""
conditions: list[ColumnElement[bool]] = []
# Handle empty list
if scenario_result_ids is not None and len(scenario_result_ids) == 0:
return []

if scenario_result_ids is not None:
if len(scenario_result_ids) == 0:
# Empty list means no results
return []
conditions.append(ScenarioResultEntry.id.in_(scenario_result_ids))
conditions: list[ColumnElement[bool]] = []

if scenario_name:
# Normalize CLI snake_case names (e.g., "foundry" or "content_harms")
# to class names (e.g., "Foundry" or "ContentHarms")
# This allows users to query with either format
normalized_name = ScenarioResult.normalize_scenario_name(scenario_name)
conditions.append(ScenarioResultEntry.scenario_name.contains(normalized_name))

Expand Down Expand Up @@ -1466,9 +1539,21 @@ def get_scenario_results(
conditions.append(self._get_scenario_result_target_model_condition(model_name=objective_target_model_name))

try:
entries: Sequence[ScenarioResultEntry] = self._query_entries(
ScenarioResultEntry, conditions=and_(*conditions) if conditions else None
)
# Handle scenario_result_ids with batching if needed
if scenario_result_ids and len(scenario_result_ids) > _SQLITE_MAX_BIND_VARS:
all_entries: MutableSequence[ScenarioResultEntry] = []
for i in range(0, len(scenario_result_ids), _SQLITE_MAX_BIND_VARS):
batch = list(scenario_result_ids)[i : i + _SQLITE_MAX_BIND_VARS]
batch_conditions = list(conditions) + [ScenarioResultEntry.id.in_(batch)]
batch_entries: Sequence[ScenarioResultEntry] = self._query_entries(
ScenarioResultEntry, conditions=and_(*batch_conditions) if batch_conditions else None
)
all_entries.extend(batch_entries)
entries = all_entries
else:
if scenario_result_ids:
conditions.append(ScenarioResultEntry.id.in_(scenario_result_ids))
entries = self._query_entries(ScenarioResultEntry, conditions=and_(*conditions) if conditions else None)

# Convert entries to ScenarioResults and populate attack_results efficiently
scenario_results = []
Expand Down
Loading