diff --git a/tests/conftest.py b/tests/conftest.py index 09167a0a..bec1f18e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,8 +20,9 @@ def load_regression_data(): + DATA_SIZE = 200 dataset = fetch_california_housing(data_home="data", as_frame=True) - df = dataset.frame.sample(5000) + df = dataset.frame.sample(DATA_SIZE) df["HouseAgeBin"] = pd.qcut(df["HouseAge"], q=4) df["HouseAgeBin"] = "age_" + df.HouseAgeBin.cat.codes.astype(str) test_idx = df.sample(int(0.2 * len(df)), random_state=42).index @@ -31,8 +32,9 @@ def load_regression_data(): def load_classification_data(): + DATA_SIZE = 200 dataset = fetch_covtype(data_home="data") - data = np.hstack([dataset.data, dataset.target.reshape(-1, 1)])[:10000, :] + data = np.hstack([dataset.data, dataset.target.reshape(-1, 1)])[:DATA_SIZE, :] col_names = [f"feature_{i}" for i in range(data.shape[-1])] col_names[-1] = "target" data = pd.DataFrame(data, columns=col_names)