ax/analysis/plotly/tests/test_marginal_effects.py (2 changes: 1 addition & 1 deletion)
@@ -43,7 +43,7 @@ def setUp(self) -> None:
self.experiment.trials[i].mark_running(no_runner_required=True)
self.experiment.attach_data(
Data(
pd.DataFrame(
df=pd.DataFrame(
{
"trial_index": [i] * num_arms,
"arm_name": [f"0_{j}" for j in range(num_arms)],
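
Review note: with the new `Data.__init__` signature introduced below, `data_rows` is the first positional parameter, so a positional `Data(pd.DataFrame(...))` would bind the frame to `data_rows`; the test now passes the frame by keyword. A minimal sketch of the updated call, with placeholder values:

import pandas as pd

from ax.core.data import Data

# Single-row frame with the columns Data requires; values are hypothetical.
data = Data(
    df=pd.DataFrame(
        {
            "trial_index": [0],
            "arm_name": ["0_0"],
            "metric_name": ["foo"],
            "metric_signature": ["foo"],
            "mean": [1.0],
            "sem": [0.1],
        }
    )
)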
ax/core/base_trial.py (4 changes: 1 addition & 3 deletions)
@@ -15,7 +15,7 @@
from typing import Any, TYPE_CHECKING

from ax.core.arm import Arm
from ax.core.data import Data, sort_by_trial_index_and_arm_name
from ax.core.data import Data
from ax.core.evaluations_to_data import raw_evaluations_to_data
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.core.metric import Metric, MetricFetchResult
@@ -442,8 +442,6 @@ def fetch_data(self, metrics: list[Metric] | None = None, **kwargs: Any) -> Data
data = Metric._unwrap_trial_data_multi(
results=self.fetch_data_results(metrics=metrics, **kwargs)
)
if not data.has_step_column:
data.full_df = sort_by_trial_index_and_arm_name(data.full_df)

return data

ax/core/data.py (207 changes: 166 additions & 41 deletions)
@@ -8,13 +8,14 @@

from __future__ import annotations

import itertools
from bisect import bisect_right
from collections.abc import Iterable
from copy import deepcopy
from functools import cached_property
from io import StringIO
from logging import Logger
from typing import Any, TypeVar
from typing import Any

import numpy as np
import numpy.typing as npt
@@ -34,11 +35,40 @@

logger: Logger = get_logger(__name__)

TData = TypeVar("TData", bound="Data")
DF_REPR_MAX_LENGTH = 1000
MAP_KEY = "step"


class DataRow:
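"""A single observation: the mean and standard error observed for one
(trial_index, arm_name, metric_name) combination, with an optional
progression step, start/end timestamps, and sample size."""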
def __init__(
self,
trial_index: int,
arm_name: str,
metric_name: str,
metric_signature: str,
mean: float,
se: float,
step: float | None = None,
start_time: int | None = None,
end_time: int | None = None,
n: int | None = None,
) -> None:
self.trial_index: int = trial_index
self.arm_name: str = arm_name

self.metric_name: str = metric_name
self.metric_signature: str = metric_signature

self.mean: float = mean
self.se: float = se

self.step: float | None = step

self.start_time: int | None = start_time
self.end_time: int | None = end_time
self.n: int | None = n
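
Review note: a minimal sketch of building a `Data` directly from `DataRow`s under the constructor below; metric and arm names are placeholders.

from ax.core.data import Data, DataRow

rows = [
    DataRow(
        trial_index=0,
        arm_name="0_0",
        metric_name="accuracy",
        metric_signature="accuracy",
        mean=0.91,
        se=0.02,
        step=0.0,  # optional progression value, surfaced as the "step" column
    ),
    DataRow(
        trial_index=0,
        arm_name="0_0",
        metric_name="accuracy",
        metric_signature="accuracy",
        mean=0.94,
        se=0.02,
        step=1.0,
    ),
]
data = Data(data_rows=rows)
assert data.has_step_column  # True: at least one row has a non-None step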


class Data(Base, SerializationMixin):
"""Class storing numerical data for an experiment.

@@ -102,8 +132,6 @@ class Data(Base, SerializationMixin):
"start_time": pd.Timestamp,
"end_time": pd.Timestamp,
"n": int,
"frac_nonnull": np.float64,
"random_split": int,
MAP_KEY: float,
}

@@ -116,16 +144,19 @@
"metric_signature",
]

full_df: pd.DataFrame
_data_rows: list[DataRow]

def __init__(
self: TData,
self,
data_rows: Iterable[DataRow] | None = None,
df: pd.DataFrame | None = None,
_skip_ordering_and_validation: bool = False,
) -> None:
"""Initialize a ``Data`` object from the given DataFrame.

Args:
data_rows: Iterable of DataRows. If provided, these are used as the
source of truth and ``df`` is ignored.
df: DataFrame with underlying data, and required columns. Data must
be unique at the level of ("trial_index", "arm_name",
"metric_name"), plus "step" if a "step" column is present. A
@@ -136,31 +167,92 @@ def __init__(
Intended only for use in `Data.filter`, where the contents
of the DataFrame are known to be ordered and valid.
"""
if df is None:
# Initialize with barebones DF with expected dtypes
self.full_df = pd.DataFrame.from_dict(
if data_rows is not None:
self._data_rows = list(data_rows)
elif df is not None:
# Unroll the df into a list of DataRows
if missing_columns := self.REQUIRED_COLUMNS - {*df.columns}:
raise ValueError(
f"Dataframe must contain required columns {list(missing_columns)}."
)

self._data_rows = [
DataRow(
# pyre-ignore[16] Intentional unsafe namedtuple access
trial_index=row.trial_index,
# pyre-ignore[16] Intentional unsafe namedtuple access
arm_name=row.arm_name,
# pyre-ignore[16] Intentional unsafe namedtuple access
metric_name=row.metric_name,
# pyre-ignore[16] Intentional unsafe namedtuple access
metric_signature=row.metric_signature,
# pyre-ignore[16] Intentional unsafe namedtuple access
mean=row.mean,
# pyre-ignore[16] Intentional unsafe namedtuple access
se=row.sem,
step=getattr(row, "step", None),
start_time=getattr(row, "start_time", None),
end_time=getattr(row, "end_time", None),
n=getattr(row, "n", None),
)
# Using itertuples() instead of iterrows() for speed
for row in df.itertuples()
]
else:
self._data_rows = []

self._memo_df: pd.DataFrame | None = None
self.has_step_column: bool = any(
row.step is not None for row in self._data_rows
)

@property
def empty(self) -> bool:
"""Whether the data is empty."""
return len(self._data_rows) == 0

@cached_property
def full_df(self) -> pd.DataFrame:
"""
Convert the DataRows into a pandas DataFrame. If step, start_time,
end_time, or n is None for every row, that column is elided.
"""
if len(self._data_rows) == 0:
return pd.DataFrame.from_dict(
{
col: pd.Series([], dtype=self.COLUMN_DATA_TYPES[col])
for col in self.REQUIRED_COLUMNS
}
)
elif _skip_ordering_and_validation:
self.full_df = df
else:
columns = set(df.columns)
missing_columns = self.REQUIRED_COLUMNS - columns
if missing_columns:
raise ValueError(
f"Dataframe must contain required columns {list(missing_columns)}."
)
# Drop rows where every input is null. Since `dropna` can be slow, first
# check trial index to see if dropping nulls might be needed.
if df["trial_index"].isnull().any():
df = df.dropna(axis=0, how="all", ignore_index=True)
df = self._safecast_df(df=df)
self.full_df = self._get_df_with_cols_in_expected_order(df=df)
self._memo_df = None
self.has_step_column = MAP_KEY in self.full_df.columns

# Detect whether any of the optional attributes are present and should be
# included as columns in the full DataFrame.
include_step = any(row.step is not None for row in self._data_rows)
include_start_time = any(row.start_time is not None for row in self._data_rows)
include_end_time = any(row.end_time is not None for row in self._data_rows)
include_n = any(row.n is not None for row in self._data_rows)

records = [
{
"trial_index": row.trial_index,
"arm_name": row.arm_name,
"metric_name": row.metric_name,
"metric_signature": row.metric_signature,
"mean": row.mean,
"sem": row.se,
**({"step": row.step} if include_step else {}),
**({"start_time": row.start_time} if include_start_time else {}),
**({"end_time": row.end_time} if include_end_time else {}),
**({"n": row.n} if include_n else {}),
}
for row in self._data_rows
]

return self._get_df_with_cols_in_expected_order(
df=self._safecast_df(
df=pd.DataFrame.from_records(records),
),
)
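
Review note: a small sketch of the column elision this property implements; names are placeholders.

row = DataRow(
    trial_index=0,
    arm_name="0_0",
    metric_name="m",
    metric_signature="m",
    mean=1.0,
    se=0.1,
)
df = Data(data_rows=[row]).full_df
assert "step" not in df.columns        # elided: step is None on every row
assert "start_time" not in df.columns  # likewise end_time and n
assert list(df["sem"]) == [0.1]        # DataRow.se maps to the "sem" column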

@classmethod
def _get_df_with_cols_in_expected_order(cls, df: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -175,7 +267,7 @@ def _get_df_with_cols_in_expected_order(cls, df: pd.DataFrame) -> pd.DataFrame:
return df

@classmethod
def _safecast_df(cls: type[TData], df: pd.DataFrame) -> pd.DataFrame:
def _safecast_df(cls, df: pd.DataFrame) -> pd.DataFrame:
"""Function for safely casting df to standard data types.

Needed because numpy does not support NaNs in integer arrays.
@@ -255,7 +347,7 @@ def df(self) -> pd.DataFrame:
return self._memo_df

# Case: Empty data
if self.full_df.empty:
if self.empty:
return self.full_df

idxs = (
@@ -275,14 +367,14 @@
return self._memo_df

@classmethod
def from_multiple_data(cls: type[TData], data: Iterable[Data]) -> TData:
def from_multiple_data(cls, data: Iterable[Data]) -> Data:
"""Combines multiple objects into one (with the concatenated
underlying dataframe).

Args:
data: Iterable of Ax objects of this class to combine.
"""
dfs = [datum.full_df for datum in data if not datum.full_df.empty]
dfs = [datum.full_df for datum in data if not datum.empty]

if len(dfs) == 0:
return cls()
@@ -302,7 +394,10 @@ def metric_names(self) -> set[str]:
"""Set of metric names that appear in the underlying dataframe of
this object.
"""
return set() if self.df.empty else set(self.df["metric_name"].values)
if self.empty:
return set()

return {row.metric_name for row in self._data_rows}

@property
def metric_signatures(self) -> set[str]:
@@ -339,21 +434,21 @@ def filter(
_skip_ordering_and_validation=True,
)

def clone(self: TData) -> TData:
def clone(self) -> Data:
"""Returns a new Data object with the same underlying dataframe."""
return self.__class__(df=deepcopy(self.full_df))

def __eq__(self, o: Data) -> bool:
return type(self) is type(o) and dataframe_equals(self.full_df, o.full_df)

def relativize(
self: TData,
self,
status_quo_name: str = "status_quo",
as_percent: bool = False,
include_sq: bool = False,
bias_correction: bool = True,
control_as_constant: bool = False,
) -> TData:
) -> Data:
"""Relativize a data object w.r.t. a status_quo arm.

Args:
@@ -391,11 +486,10 @@ def relativize(
@cached_property
def trial_indices(self) -> set[int]:
"""Return the set of trial indices in the data."""
if self._memo_df is not None:
# Use a smaller df if available
return set(self.df["trial_index"].unique())
# If no small df is available, use the full df
return set(self.full_df["trial_index"].unique())
if self.empty:
return set()

return {row.trial_index for row in self._data_rows}

def latest(self, rows_per_group: int = 1) -> Data:
"""Return a new Data with the most recently observed `rows_per_group`
@@ -437,12 +531,12 @@ def latest(self, rows_per_group: int = 1) -> Data:
)

def subsample(
self: TData,
self,
keep_every: int | None = None,
limit_rows_per_group: int | None = None,
limit_rows_per_metric: int | None = None,
include_first_last: bool = True,
) -> TData:
) -> Data:
"""Return a new Data that subsamples the `MAP_KEY` column in an
equally-spaced manner. This function considers only the relative ordering
of the `MAP_KEY` values, making it most suitable when these values are
@@ -504,6 +598,37 @@ def subsample(
return self.__class__(df=subsampled_df)


def combine_data_rows_favoring_recent(
last_rows: Iterable[DataRow], new_rows: Iterable[DataRow]
) -> list[DataRow]:
"""Combine last_rows and new_rows.

Deduplicate in favor of new_rows when there are multiple observations with
the same "trial_index", "metric_name", "arm_name", and "step".

Args:
last_rows: The rows of data currently attached to a trial
new_rows: A list of rows containing new data to be attached
"""

deduped: dict[tuple[int, str, str, float | None], DataRow] = {}

# Loop over all rows without creating a new list in memory
for row in itertools.chain(last_rows, new_rows):
# NaN must be treated specially since NaN != NaN
if row.step is not None and np.isnan(row.step):
step_key = None
else:
step_key = row.step

key = (row.trial_index, row.metric_name, row.arm_name, step_key)
deduped[key] = row

return list(deduped.values())
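
Review note: a quick sketch of the dedup behavior with placeholder names; the row from `new_rows` wins because it overwrites the same dict key last. Both rows below have step=None, so their keys match.

stale = DataRow(
    trial_index=0,
    arm_name="0_0",
    metric_name="m",
    metric_signature="m",
    mean=0.5,
    se=0.1,
)
fresh = DataRow(
    trial_index=0,
    arm_name="0_0",
    metric_name="m",
    metric_signature="m",
    mean=0.7,
    se=0.1,
)
combined = combine_data_rows_favoring_recent(last_rows=[stale], new_rows=[fresh])
assert len(combined) == 1 and combined[0].mean == 0.7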


# This function is only used in ax/storage and can be removed
# once storage is refactored to use DataRows.
def combine_dfs_favoring_recent(
last_df: pd.DataFrame, new_df: pd.DataFrame
) -> pd.DataFrame: