diff --git a/.gitignore b/.gitignore index 742930a33..a7733b802 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,12 @@ new_polytope_venv *.json venv_python3_11 *.txt -tests/data \ No newline at end of file +tests/data +venv_gj_iterator +new_venv_gj_iterator +**/build +*.so +*.lock +**/_version.py +rust_deployment_venv +**/target \ No newline at end of file diff --git a/polytope_feature/datacube/backends/datacube.py b/polytope_feature/datacube/backends/datacube.py index e8e41db0b..76910e296 100644 --- a/polytope_feature/datacube/backends/datacube.py +++ b/polytope_feature/datacube/backends/datacube.py @@ -128,6 +128,7 @@ def get_indices(self, path: DatacubePath, axis, lower, upper, method=None): e.g. returns integer discrete points between two floats """ path = self.fit_path(path) + print(path) indexes = axis.find_indexes(path, self) idx_between = axis.find_indices_between(indexes, lower, upper, self, method) @@ -166,6 +167,12 @@ def create(datacube, config={}, axis_options={}, compressed_axes_options=[], alt datacube, config, axis_options, compressed_axes_options, alternative_axes, context ) return fdbdatacube + if type(datacube).__name__ == "QubedDatacube": + from .qubed import QubedDatacube + # TODO: here we create the qubeddatacube twice..., which we do not want + qubed_datacube = QubedDatacube(datacube.q, datacube.datacube_axes, datacube.datacube_transformations, + config, axis_options, compressed_axes_options, alternative_axes, context) + return qubed_datacube def check_branching_axes(self, request): pass diff --git a/polytope_feature/datacube/backends/fdb.py b/polytope_feature/datacube/backends/fdb.py index 64304e379..b69138c8d 100644 --- a/polytope_feature/datacube/backends/fdb.py +++ b/polytope_feature/datacube/backends/fdb.py @@ -140,7 +140,7 @@ def get(self, requests: TensorIndexTree, context=None): logging.debug("The requests we give GribJump are: %s", printed_list_to_gj) logging.info("Requests given to GribJump extract for %s", context) try: - output_values = self.gj.extract(complete_list_complete_uncompressed_requests, context) + iterator = self.gj.extract(complete_list_complete_uncompressed_requests, context) except Exception as e: if "BadValue: Grid hash mismatch" in str(e): logging.info("Error is: %s", e) @@ -152,10 +152,7 @@ def get(self, requests: TensorIndexTree, context=None): raise e logging.info("Requests extracted from GribJump for %s", context) - if logging.root.level <= logging.DEBUG: - printed_output_values = output_values[::1000] - logging.debug("GribJump outputs: %s", printed_output_values) - self.assign_fdb_output_to_nodes(output_values, complete_fdb_decoding_info) + self.assign_fdb_output_to_nodes(iterator, complete_fdb_decoding_info) def get_fdb_requests( self, @@ -321,9 +318,8 @@ def get_last_layer_before_leaf(self, requests, leaf_path, current_idx, fdb_range fdb_range_n[i].append(c) return (current_idx, fdb_range_n) - def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info): - for k in range(len(output_values)): - request_output_values = output_values[k] + def assign_fdb_output_to_nodes(self, output_iterator, fdb_requests_decoding_info): + for k, result in enumerate(output_iterator): ( original_indices, fdb_node_ranges, @@ -331,13 +327,12 @@ def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info): sorted_fdb_range_nodes = [fdb_node_ranges[i] for i in original_indices] for i in range(len(sorted_fdb_range_nodes)): n = sorted_fdb_range_nodes[i][0] - if len(request_output_values[0]) == 0: + if len(result.values) == 0: # If we are here, no data was found for this path in the fdb none_array = [None] * len(n.values) n.result.extend(none_array) else: - interm_request_output_values = request_output_values[0][i][0] - n.result.extend(interm_request_output_values) + n.result.extend(result.values[i]) def sort_fdb_request_ranges(self, current_start_idx, lat_length, fdb_node_ranges): (new_fdb_node_ranges, new_current_start_idx) = self.remove_duplicates_in_request_ranges( diff --git a/polytope_feature/datacube/backends/qubed.py b/polytope_feature/datacube/backends/qubed.py new file mode 100644 index 000000000..6d40fe266 --- /dev/null +++ b/polytope_feature/datacube/backends/qubed.py @@ -0,0 +1,385 @@ +import logging +import operator +from copy import deepcopy +from itertools import product +from ...utility.exceptions import BadGridError, BadRequestError, GribJumpNoIndexError +from ...utility.geometry import nearest_pt +import pygribjump as pygj +from qubed.value_types import QEnum +import numpy as np + +from .datacube import Datacube, TensorIndexTree + + +class QubedDatacube(Datacube): + + def __init__( + self, q, datacube_axes, datacube_transformations, config=None, axis_options=None, compressed_axes_options=[], alternative_axes=[], context=None + ): + if config is None: + config = {} + if axis_options is None: + axis_options = {} + + self.q = q + # TODO: find datacube_axes and datacube_transformations from options like other datacube backends + self.datacube_axes = datacube_axes + self.datacube_transformations = datacube_transformations + # TODO: find compressed_axes list + self.compressed_axes = [] + # TODO: should the gj object be passed in instead? + self.gj = pygj.GribJump() + + # TODO: this doesn't fill the axes as wanted + super().__init__(axis_options, compressed_axes_options) + + # TODO: where do these come from and are they right? + self.unwanted_path = {} + + # TODO: is this right? + + self.axis_options = axis_options + # Find values in the level 3 FDB datacube + + self.fdb_coordinates = {} + + # TODO: we instead now have a list of axes with the actual axes types... + # TODO: here use the qubed to find all axes names and then get the values from the first val of the qubed and then apply transformations to get the actual right axis type... + for axis_name in datacube_axes: + axis = datacube_axes[axis_name] + self.fdb_coordinates[axis_name] = [axis.type_eg] + + self.fdb_coordinates["values"] = [] + for name, values in self.fdb_coordinates.items(): + options = None + for opt in self.axis_options: + if opt.axis_name == name: + options = opt + + self._check_and_add_axes(options, name, values) + self.treated_axes.append(name) + self.complete_axes.append(name) + + # add other options to axis which were just created above like "lat" for the mapper transformations for eg + for name in self._axes: + if name not in self.treated_axes: + options = None + for opt in self.axis_options: + if opt.axis_name == name: + options = opt + + val = self._axes[name].type_eg + self._check_and_add_axes(options, name, val) + + def datacube_natural_indexes(self, qube_node): + if qube_node is not None: + return np.asarray(list(qube_node.values)) + else: + return [] + + def get_indices(self, path, path_node, axis, lower, upper, method=None): + """ + Given a path to a subset of the datacube, return the discrete indexes which exist between + two non-discrete values (lower, upper) for a particular axis (given by label) + If lower and upper are equal, returns the index which exactly matches that value (if it exists) + e.g. returns integer discrete points between two floats + """ + # path = self.fit_path(path) + indexes = axis.find_indexes_node(path_node, self, path) + + idx_between = axis.find_indices_between(indexes, lower, upper, self, method) + + logging.debug(f"For axis {axis.name} between {lower} and {upper}, found indices {idx_between}") + + return idx_between + + def get(self, requests, context=None): + if context is None: + context = {} + if len(requests.children) == 0: + return requests + fdb_requests = [] + fdb_requests_decoding_info = [] + self.get_fdb_requests(requests, fdb_requests, fdb_requests_decoding_info) + + # here, loop through the fdb requests and request from gj and directly add to the nodes + complete_list_complete_uncompressed_requests = [] + complete_fdb_decoding_info = [] + for j, compressed_request in enumerate(fdb_requests): + uncompressed_request = {} + + # Need to determine the possible decompressed requests + + # find the possible combinations of compressed indices + interm_branch_tuple_values = [] + for key in compressed_request[0].keys(): + interm_branch_tuple_values.append(compressed_request[0][key]) + request_combis = product(*interm_branch_tuple_values) + + # Need to extract the possible requests and add them to the right nodes + for combi in request_combis: + uncompressed_request = {} + for i, key in enumerate(compressed_request[0].keys()): + uncompressed_request[key] = combi[i] + # TODO: get the hash from somewhere... + # self.grid_md5_hash = "cbda19e48d4d7e5e22641154878b9b22" + complete_uncompressed_request = (uncompressed_request, compressed_request[1], self.grid_md5_hash) + complete_list_complete_uncompressed_requests.append(complete_uncompressed_request) + complete_fdb_decoding_info.append(fdb_requests_decoding_info[j]) + + if logging.root.level <= logging.DEBUG: + printed_list_to_gj = complete_list_complete_uncompressed_requests[::1000] + logging.debug("The requests we give GribJump are: %s", printed_list_to_gj) + logging.info("Requests given to GribJump extract for %s", context) + try: + output_values = self.gj.extract(complete_list_complete_uncompressed_requests, context) + except Exception as e: + if "BadValue: Grid hash mismatch" in str(e): + logging.info("Error is: %s", e) + raise BadGridError() + if "Missing JumpInfo" in str(e): + logging.info("Error is: %s", e) + raise GribJumpNoIndexError() + else: + raise e + + logging.info("Requests extracted from GribJump for %s", context) + if logging.root.level <= logging.DEBUG: + printed_output_values = output_values[::1000] + logging.debug("GribJump outputs: %s", printed_output_values) + self.assign_fdb_output_to_nodes(output_values, complete_fdb_decoding_info) + + def get_fdb_requests( + self, + requests, + fdb_requests=[], + fdb_requests_decoding_info=[], + leaf_path=None, + ): + if leaf_path is None: + leaf_path = {} + + # First when request node is root, go to its children + if requests.key == "root": + logging.debug("Looking for data for the tree") + + for c in requests.children: + self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info) + # If request node has no children, we have a leaf so need to assign fdb values to it + else: + key_value_path = {requests.key: requests.values} + ax = self._axes[requests.key] + (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( + key_value_path, leaf_path, self.unwanted_path + ) + # TODO: change to use the datacube trasnformations instead... + if requests.key == "time": + new_vals = [] + for val in key_value_path[requests.key]: + new_vals.append(val[7:9]+val[10:12]) + key_value_path[requests.key] = new_vals + if requests.key == "date": + new_vals = [] + for val in key_value_path[requests.key]: + new_vals.append(val[:4] + val[5:7] + val[8:10]) + key_value_path[requests.key] = new_vals + leaf_path.update(key_value_path) + if len(requests.children[0].children[0].children) == 0: + # find the fdb_requests and associated nodes to which to add results + (path, current_start_idxs, fdb_node_ranges, lat_length) = self.get_2nd_last_values(requests, leaf_path) + ( + original_indices, + sorted_request_ranges, + fdb_node_ranges, + ) = self.sort_fdb_request_ranges(current_start_idxs, lat_length, fdb_node_ranges) + fdb_requests.append((path, sorted_request_ranges)) + fdb_requests_decoding_info.append((original_indices, fdb_node_ranges)) + + # Otherwise remap the path for this key and iterate again over children + else: + for c in requests.children: + self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info, leaf_path) + + def remove_duplicates_in_request_ranges(self, fdb_node_ranges, current_start_idxs): + # seen_indices = set() + # for i, idxs_list in enumerate(current_start_idxs): + # for k, sub_lat_idxs in enumerate(idxs_list): + # actual_fdb_node = fdb_node_ranges[i][k] + # original_fdb_node_range_vals = [] + # new_current_start_idx = [] + # for j, idx in enumerate(sub_lat_idxs): + # if idx not in seen_indices: + # # NOTE: need to remove it from the values in the corresponding tree node + # # NOTE: need to read just the range we give to gj + # original_fdb_node_range_vals.append(list(actual_fdb_node[0].values)[j]) + # seen_indices.add(idx) + # new_current_start_idx.append(idx) + # if original_fdb_node_range_vals != []: + # actual_fdb_node[0].values = tuple(original_fdb_node_range_vals) + # else: + # # there are no values on this node anymore so can remove it + # actual_fdb_node[0].remove_branch() + # if len(new_current_start_idx) == 0: + # current_start_idxs[i].pop(k) + # else: + # current_start_idxs[i][k] = new_current_start_idx + return (fdb_node_ranges, current_start_idxs) + + def nearest_lat_lon_search(self, requests): + if len(self.nearest_search) != 0: + first_ax_name = requests.children[0].key + second_ax_name = requests.children[0].children[0].key + + axes_in_nearest_search = [ + first_ax_name not in self.nearest_search.keys(), + second_ax_name not in self.nearest_search.keys(), + ] + + if all(not item for item in axes_in_nearest_search): + raise Exception("nearest point search axes are wrong") + + second_ax = self._axes[requests.children[0].children[0].key] + + nearest_pts = self.nearest_search.get((first_ax_name, second_ax_name), None) + if nearest_pts is None: + nearest_pts = self.nearest_search.get((second_ax_name, first_ax_name), None) + for i, pt in enumerate(nearest_pts): + nearest_pts[i] = [pt[1], pt[0]] + + transformed_nearest_pts = [] + for point in nearest_pts: + transformed_nearest_pts.append([point[0], second_ax._remap_val_to_axis_range(point[1])]) + + found_latlon_pts = [] + for lat_child in requests.children: + for lon_child in lat_child.children: + found_latlon_pts.append([lat_child.values, lon_child.values]) + + # now find the nearest lat lon to the points requested + nearest_latlons = [] + for pt in transformed_nearest_pts: + nearest_latlon = nearest_pt(found_latlon_pts, pt) + nearest_latlons.append(nearest_latlon) + + # need to remove the branches that do not fit + lat_children_values = [child.values for child in requests.children] + for i in range(len(lat_children_values)): + lat_child_val = lat_children_values[i] + lat_child = [child for child in requests.children if child.values == lat_child_val][0] + if lat_child.values not in [(latlon[0],) for latlon in nearest_latlons]: + lat_child.remove_branch() + else: + possible_lons = [latlon[1] for latlon in nearest_latlons if (latlon[0],) == lat_child.values] + lon_children_values = [child.values for child in lat_child.children] + for j in range(len(lon_children_values)): + lon_child_val = lon_children_values[j] + lon_child = [child for child in lat_child.children if child.values == lon_child_val][0] + for value in lon_child.values: + if value not in possible_lons: + lon_child.remove_compressed_branch(value) + + def get_2nd_last_values(self, requests, leaf_path=None): + if leaf_path is None: + leaf_path = {} + # In this function, we recursively loop over the last two layers of the tree and store the indices of the + # request ranges in those layers + self.nearest_lat_lon_search(requests) + + lat_length = len(requests.children) + current_start_idxs = [False] * lat_length + fdb_node_ranges = [False] * lat_length + for i in range(len(requests.children)): + lat_child = requests.children[i] + lon_length = len(lat_child.children) + current_start_idxs[i] = [None] * lon_length + fdb_node_ranges[i] = [[TensorIndexTree.root for y in range(lon_length)] for x in range(lon_length)] + current_start_idx = deepcopy(current_start_idxs[i]) + fdb_range_nodes = deepcopy(fdb_node_ranges[i]) + key_value_path = {lat_child.key: list(lat_child.values)} + ax = self._axes[lat_child.key] + (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( + key_value_path, leaf_path, self.unwanted_path + ) + leaf_path.update(key_value_path) + (current_start_idxs[i], fdb_node_ranges[i]) = self.get_last_layer_before_leaf( + lat_child, leaf_path, current_start_idx, fdb_range_nodes + ) + + leaf_path_copy = deepcopy(leaf_path) + leaf_path_copy.pop("values", None) + return (leaf_path_copy, current_start_idxs, fdb_node_ranges, lat_length) + + def get_last_layer_before_leaf(self, requests, leaf_path, current_idx, fdb_range_n): + current_idx = [[] for i in range(len(requests.children))] + fdb_range_n = [[] for i in range(len(requests.children))] + for i, c in enumerate(requests.children): + # now c are the leaves of the initial tree + key_value_path = {c.key: list(c.values)} + ax = self._axes[c.key] + (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( + key_value_path, leaf_path, self.unwanted_path + ) + # TODO: change this to accommodate non consecutive indexes being compressed too + current_idx[i].extend(key_value_path["values"]) + fdb_range_n[i].append(c) + return (current_idx, fdb_range_n) + + def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info): + for k, request_output_values in enumerate(output_values): + ( + original_indices, + fdb_node_ranges, + ) = fdb_requests_decoding_info[k] + sorted_fdb_range_nodes = [fdb_node_ranges[i] for i in original_indices] + for i in range(len(sorted_fdb_range_nodes)): + n = sorted_fdb_range_nodes[i][0] + if len(request_output_values.values) == 0: + # If we are here, no data was found for this path in the fdb + none_array = [None] * len(n.values) + if n.data.metadata.get("result", None) is None: + n.data.metadata["result"] = [] + n.data.metadata["result"].extend(none_array) + else: + if n.data.metadata.get("result", None) is None: + n.data.metadata["result"] = [] + n.data.metadata["result"].extend(request_output_values.values[i]) + + def sort_fdb_request_ranges(self, current_start_idx, lat_length, fdb_node_ranges): + (new_fdb_node_ranges, new_current_start_idx) = self.remove_duplicates_in_request_ranges( + fdb_node_ranges, current_start_idx + ) + interm_request_ranges = [] + # TODO: modify the start indexes to have as many arrays as the request ranges + new_fdb_node_ranges = [] + for i in range(lat_length): + interm_fdb_nodes = fdb_node_ranges[i] + old_interm_start_idx = current_start_idx[i] + for j in range(len(old_interm_start_idx)): + # TODO: if we sorted the cyclic values in increasing order on the tree too, + # then we wouldn't have to sort here? + sorted_list = sorted(enumerate(old_interm_start_idx[j]), key=lambda x: x[1]) + original_indices_idx, interm_start_idx = zip(*sorted_list) + for interm_fdb_nodes_obj in interm_fdb_nodes[j]: + interm_fdb_nodes_obj.data.values = QEnum(tuple([list(interm_fdb_nodes_obj.values)[k] + for k in original_indices_idx])) + if abs(interm_start_idx[-1] + 1 - interm_start_idx[0]) <= len(interm_start_idx): + current_request_ranges = (interm_start_idx[0], interm_start_idx[-1] + 1) + interm_request_ranges.append(current_request_ranges) + new_fdb_node_ranges.append(interm_fdb_nodes[j]) + else: + jumps = list(map(operator.sub, interm_start_idx[1:], interm_start_idx[:-1])) + last_idx = 0 + for k, jump in enumerate(jumps): + if jump > 1: + current_request_ranges = (interm_start_idx[last_idx], interm_start_idx[k] + 1) + new_fdb_node_ranges.append(interm_fdb_nodes[j]) + last_idx = k + 1 + interm_request_ranges.append(current_request_ranges) + if k == len(interm_start_idx) - 2: + current_request_ranges = (interm_start_idx[last_idx], interm_start_idx[-1] + 1) + interm_request_ranges.append(current_request_ranges) + new_fdb_node_ranges.append(interm_fdb_nodes[j]) + request_ranges_with_idx = list(enumerate(interm_request_ranges)) + sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0]) + original_indices, sorted_request_ranges = zip(*sorted_list) + return (original_indices, sorted_request_ranges, new_fdb_node_ranges) diff --git a/polytope_feature/datacube/backends/qubed_fdb.py b/polytope_feature/datacube/backends/qubed_fdb.py new file mode 100644 index 000000000..dcd8b31fa --- /dev/null +++ b/polytope_feature/datacube/backends/qubed_fdb.py @@ -0,0 +1,403 @@ +import logging +import operator +from copy import deepcopy +from itertools import product + +from ...utility.exceptions import BadGridError, BadRequestError +from ...utility.geometry import nearest_pt +from .datacube import Datacube, TensorIndexTree +from qubed import Qube +import requests + + +class FDBDatacube(Datacube): + def __init__( + self, gj, config=None, axis_options=None, compressed_axes_options=[], alternative_axes=[], context=None + ): + if config is None: + config = {} + if context is None: + context = {} + + super().__init__(axis_options, compressed_axes_options) + + logging.info("Created an FDB datacube with options: " + str(axis_options)) + + self.unwanted_path = {} + self.axis_options = axis_options + + # partial_request = config + # Find values in the level 3 FDB datacube + + self.gj = gj + self.fdb_tree = Qube.from_json(requests.get( + "https://github.com/ecmwf/qubed/raw/refs/heads/main/tests/example_qubes/climate_dt.json").json()) + + if len(alternative_axes) == 0: + logging.info("Find GribJump axes for %s", context) + # TODO: change to get the axes from + # self.fdb_coordinates = self.gj.axes(partial_request, ctx=context) + self.fdb_coordinates = self.fdb_tree.axes() + logging.info("Retrieved available GribJump axes for %s", context) + if len(self.fdb_coordinates) == 0: + raise BadRequestError({}) + else: + self.fdb_coordinates = {} + for axis_config in alternative_axes: + self.fdb_coordinates[axis_config.axis_name] = axis_config.values + + fdb_coordinates_copy = deepcopy(self.fdb_coordinates) + for axis, vals in fdb_coordinates_copy.items(): + if len(vals) == 1: + if vals[0] == "": + self.fdb_coordinates.pop(axis) + + logging.info("Axes returned from GribJump are: " + str(self.fdb_coordinates)) + + self.fdb_coordinates["values"] = [] + for name, values in self.fdb_coordinates.items(): + values.sort() + options = None + for opt in self.axis_options: + if opt.axis_name == name: + options = opt + + self._check_and_add_axes(options, name, values) + self.treated_axes.append(name) + self.complete_axes.append(name) + + # add other options to axis which were just created above like "lat" for the mapper transformations for eg + for name in self._axes: + if name not in self.treated_axes: + options = None + for opt in self.axis_options: + if opt.axis_name == name: + options = opt + + val = self._axes[name].type + self._check_and_add_axes(options, name, val) + + logging.info("Polytope created axes for %s", self._axes.keys()) + + # def check_branching_axes(self, request): + # polytopes = request.polytopes() + # for polytope in polytopes: + # for ax in polytope._axes: + # if ax == "levtype": + # (upper, lower, idx) = polytope.extents(ax) + # if "sfc" in polytope.points[idx]: + # self.fdb_coordinates.pop("levelist", None) + + # if ax == "param": + # (upper, lower, idx) = polytope.extents(ax) + # if "140251" not in polytope.points[idx]: + # self.fdb_coordinates.pop("direction", None) + # self.fdb_coordinates.pop("frequency", None) + # else: + # # special param with direction and frequency + # if len(polytope.points[idx]) > 1: + # raise ValueError( + # "Param 251 is part of a special branching of the datacube. Please request it separately." # noqa: E501 + # ) + # self.fdb_coordinates.pop("quantile", None) + # self.fdb_coordinates.pop("year", None) + # self.fdb_coordinates.pop("month", None) + + # # NOTE: verify that we also remove the axis object for axes we've removed here + # axes_to_remove = set(self.complete_axes) - set(self.fdb_coordinates.keys()) + + # # Remove the keys from self._axes + # for axis_name in axes_to_remove: + # self._axes.pop(axis_name, None) + + def get(self, requests: TensorIndexTree, context=None): + if context is None: + context = {} + if len(requests.children) == 0: + return requests + fdb_requests = [] + fdb_requests_decoding_info = [] + self.get_fdb_requests(requests, fdb_requests, fdb_requests_decoding_info) + + # here, loop through the fdb requests and request from gj and directly add to the nodes + complete_list_complete_uncompressed_requests = [] + complete_fdb_decoding_info = [] + for j, compressed_request in enumerate(fdb_requests): + uncompressed_request = {} + + # Need to determine the possible decompressed requests + + # find the possible combinations of compressed indices + interm_branch_tuple_values = [] + for key in compressed_request[0].keys(): + interm_branch_tuple_values.append(compressed_request[0][key]) + request_combis = product(*interm_branch_tuple_values) + + # Need to extract the possible requests and add them to the right nodes + for combi in request_combis: + uncompressed_request = {} + for i, key in enumerate(compressed_request[0].keys()): + uncompressed_request[key] = combi[i] + complete_uncompressed_request = (uncompressed_request, compressed_request[1], self.grid_md5_hash) + complete_list_complete_uncompressed_requests.append(complete_uncompressed_request) + complete_fdb_decoding_info.append(fdb_requests_decoding_info[j]) + + if logging.root.level <= logging.DEBUG: + printed_list_to_gj = complete_list_complete_uncompressed_requests[::1000] + logging.debug("The requests we give GribJump are: %s", printed_list_to_gj) + logging.info("Requests given to GribJump extract for %s", context) + try: + output_values = self.gj.extract(complete_list_complete_uncompressed_requests, context) + except Exception as e: + if "BadValue: Grid hash mismatch" in str(e): + logging.info("Error is: %s", e) + raise BadGridError() + else: + raise e + + logging.info("Requests extracted from GribJump for %s", context) + if logging.root.level <= logging.DEBUG: + printed_output_values = output_values[::1000] + logging.debug("GribJump outputs: %s", printed_output_values) + self.assign_fdb_output_to_nodes(output_values, complete_fdb_decoding_info) + + def get_fdb_requests( + self, + requests: TensorIndexTree, + fdb_requests=[], + fdb_requests_decoding_info=[], + leaf_path=None, + ): + if leaf_path is None: + leaf_path = {} + + # First when request node is root, go to its children + if requests.axis.name == "root": + logging.debug("Looking for data for the tree") + + for c in requests.children: + self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info) + # If request node has no children, we have a leaf so need to assign fdb values to it + else: + key_value_path = {requests.axis.name: requests.values} + ax = requests.axis + (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( + key_value_path, leaf_path, self.unwanted_path + ) + leaf_path.update(key_value_path) + if len(requests.children[0].children[0].children) == 0: + # find the fdb_requests and associated nodes to which to add results + (path, current_start_idxs, fdb_node_ranges, lat_length) = self.get_2nd_last_values(requests, leaf_path) + ( + original_indices, + sorted_request_ranges, + fdb_node_ranges, + ) = self.sort_fdb_request_ranges(current_start_idxs, lat_length, fdb_node_ranges) + fdb_requests.append((path, sorted_request_ranges)) + fdb_requests_decoding_info.append((original_indices, fdb_node_ranges)) + + # Otherwise remap the path for this key and iterate again over children + else: + for c in requests.children: + self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info, leaf_path) + + def remove_duplicates_in_request_ranges(self, fdb_node_ranges, current_start_idxs): + seen_indices = set() + for i, idxs_list in enumerate(current_start_idxs): + for k, sub_lat_idxs in enumerate(idxs_list): + actual_fdb_node = fdb_node_ranges[i][k] + original_fdb_node_range_vals = [] + new_current_start_idx = [] + for j, idx in enumerate(sub_lat_idxs): + if idx not in seen_indices: + # NOTE: need to remove it from the values in the corresponding tree node + # NOTE: need to read just the range we give to gj + original_fdb_node_range_vals.append(actual_fdb_node[0].values[j]) + seen_indices.add(idx) + new_current_start_idx.append(idx) + if original_fdb_node_range_vals != []: + actual_fdb_node[0].values = tuple(original_fdb_node_range_vals) + else: + # there are no values on this node anymore so can remove it + actual_fdb_node[0].remove_branch() + if len(new_current_start_idx) == 0: + current_start_idxs[i].pop(k) + else: + current_start_idxs[i][k] = new_current_start_idx + return (fdb_node_ranges, current_start_idxs) + + def nearest_lat_lon_search(self, requests): + if len(self.nearest_search) != 0: + first_ax_name = requests.children[0].axis.name + second_ax_name = requests.children[0].children[0].axis.name + + if first_ax_name not in self.nearest_search.keys() or second_ax_name not in self.nearest_search.keys(): + raise Exception("nearest point search axes are wrong") + + second_ax = requests.children[0].children[0].axis + + nearest_pts = [ + [lat_val, second_ax._remap_val_to_axis_range(lon_val)] + for (lat_val, lon_val) in zip( + self.nearest_search[first_ax_name][0], self.nearest_search[second_ax_name][0] + ) + ] + + found_latlon_pts = [] + for lat_child in requests.children: + for lon_child in lat_child.children: + found_latlon_pts.append([lat_child.values, lon_child.values]) + + # now find the nearest lat lon to the points requested + nearest_latlons = [] + for pt in nearest_pts: + nearest_latlon = nearest_pt(found_latlon_pts, pt) + nearest_latlons.append(nearest_latlon) + + # need to remove the branches that do not fit + lat_children_values = [child.values for child in requests.children] + for i in range(len(lat_children_values)): + lat_child_val = lat_children_values[i] + lat_child = [child for child in requests.children if child.values == lat_child_val][0] + if lat_child.values not in [(latlon[0],) for latlon in nearest_latlons]: + lat_child.remove_branch() + else: + possible_lons = [latlon[1] for latlon in nearest_latlons if (latlon[0],) == lat_child.values] + lon_children_values = [child.values for child in lat_child.children] + for j in range(len(lon_children_values)): + lon_child_val = lon_children_values[j] + lon_child = [child for child in lat_child.children if child.values == lon_child_val][0] + for value in lon_child.values: + if value not in possible_lons: + lon_child.remove_compressed_branch(value) + + def get_2nd_last_values(self, requests, leaf_path=None): + if leaf_path is None: + leaf_path = {} + # In this function, we recursively loop over the last two layers of the tree and store the indices of the + # request ranges in those layers + self.nearest_lat_lon_search(requests) + + lat_length = len(requests.children) + current_start_idxs = [False] * lat_length + fdb_node_ranges = [False] * lat_length + for i in range(len(requests.children)): + lat_child = requests.children[i] + lon_length = len(lat_child.children) + current_start_idxs[i] = [None] * lon_length + fdb_node_ranges[i] = [[TensorIndexTree.root for y in range(lon_length)] for x in range(lon_length)] + current_start_idx = deepcopy(current_start_idxs[i]) + fdb_range_nodes = deepcopy(fdb_node_ranges[i]) + key_value_path = {lat_child.axis.name: lat_child.values} + ax = lat_child.axis + (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( + key_value_path, leaf_path, self.unwanted_path + ) + leaf_path.update(key_value_path) + (current_start_idxs[i], fdb_node_ranges[i]) = self.get_last_layer_before_leaf( + lat_child, leaf_path, current_start_idx, fdb_range_nodes + ) + + leaf_path_copy = deepcopy(leaf_path) + leaf_path_copy.pop("values", None) + return (leaf_path_copy, current_start_idxs, fdb_node_ranges, lat_length) + + def get_last_layer_before_leaf(self, requests, leaf_path, current_idx, fdb_range_n): + current_idx = [[] for i in range(len(requests.children))] + fdb_range_n = [[] for i in range(len(requests.children))] + for i, c in enumerate(requests.children): + # now c are the leaves of the initial tree + key_value_path = {c.axis.name: c.values} + ax = c.axis + (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( + key_value_path, leaf_path, self.unwanted_path + ) + # TODO: change this to accommodate non consecutive indexes being compressed too + current_idx[i].extend(key_value_path["values"]) + fdb_range_n[i].append(c) + return (current_idx, fdb_range_n) + + def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info): + for k in range(len(output_values)): + request_output_values = output_values[k] + ( + original_indices, + fdb_node_ranges, + ) = fdb_requests_decoding_info[k] + sorted_fdb_range_nodes = [fdb_node_ranges[i] for i in original_indices] + for i in range(len(sorted_fdb_range_nodes)): + n = sorted_fdb_range_nodes[i][0] + if len(request_output_values[0]) == 0: + # If we are here, no data was found for this path in the fdb + none_array = [None] * len(n.values) + n.result.extend(none_array) + else: + interm_request_output_values = request_output_values[0][i][0] + n.result.extend(interm_request_output_values) + + def sort_fdb_request_ranges(self, current_start_idx, lat_length, fdb_node_ranges): + (new_fdb_node_ranges, new_current_start_idx) = self.remove_duplicates_in_request_ranges( + fdb_node_ranges, current_start_idx + ) + interm_request_ranges = [] + # TODO: modify the start indexes to have as many arrays as the request ranges + new_fdb_node_ranges = [] + for i in range(lat_length): + interm_fdb_nodes = fdb_node_ranges[i] + old_interm_start_idx = current_start_idx[i] + for j in range(len(old_interm_start_idx)): + # TODO: if we sorted the cyclic values in increasing order on the tree too, + # then we wouldn't have to sort here? + sorted_list = sorted(enumerate(old_interm_start_idx[j]), key=lambda x: x[1]) + original_indices_idx, interm_start_idx = zip(*sorted_list) + for interm_fdb_nodes_obj in interm_fdb_nodes[j]: + interm_fdb_nodes_obj.values = tuple([interm_fdb_nodes_obj.values[k] for k in original_indices_idx]) + if abs(interm_start_idx[-1] + 1 - interm_start_idx[0]) <= len(interm_start_idx): + current_request_ranges = (interm_start_idx[0], interm_start_idx[-1] + 1) + interm_request_ranges.append(current_request_ranges) + new_fdb_node_ranges.append(interm_fdb_nodes[j]) + else: + jumps = list(map(operator.sub, interm_start_idx[1:], interm_start_idx[:-1])) + last_idx = 0 + for k, jump in enumerate(jumps): + if jump > 1: + current_request_ranges = (interm_start_idx[last_idx], interm_start_idx[k] + 1) + new_fdb_node_ranges.append(interm_fdb_nodes[j]) + last_idx = k + 1 + interm_request_ranges.append(current_request_ranges) + if k == len(interm_start_idx) - 2: + current_request_ranges = (interm_start_idx[last_idx], interm_start_idx[-1] + 1) + interm_request_ranges.append(current_request_ranges) + new_fdb_node_ranges.append(interm_fdb_nodes[j]) + request_ranges_with_idx = list(enumerate(interm_request_ranges)) + sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0]) + original_indices, sorted_request_ranges = zip(*sorted_list) + return (original_indices, sorted_request_ranges, new_fdb_node_ranges) + + def datacube_natural_indexes(self, axis, subarray): + indexes = subarray.get(axis.name, None) + return indexes + + def select(self, path, unmapped_path): + return self.fdb_coordinates + + def ax_vals(self, name): + return self.fdb_coordinates.get(name, None) + + def prep_tree_encoding(self, node, unwanted_path=None): + # TODO: prepare the tree for protobuf encoding + # ie transform all axes for gribjump and adding the index property on the leaves + if unwanted_path is None: + unwanted_path = {} + + ax = node.axis + (new_node, unwanted_path) = ax.unmap_tree_node(node, unwanted_path) + + if len(node.children) != 0: + for c in new_node.children: + self.prep_tree_encoding(c, unwanted_path) + + def prep_tree_decoding(self, tree): + # TODO: transform the tree after decoding from protobuf + # ie unstransform all axes from gribjump and put the indexes back as a leaf/extra node + pass diff --git a/polytope_feature/datacube/datacube_axis.py b/polytope_feature/datacube/datacube_axis.py index fbf8a70cb..b5b4869f5 100644 --- a/polytope_feature/datacube/datacube_axis.py +++ b/polytope_feature/datacube/datacube_axis.py @@ -78,6 +78,22 @@ def find_indexes(self, path, datacube): indexes = transformation.find_modified_indexes(indexes, path, datacube, self) return indexes + def find_standard_indexes_node(self, path_node, datacube): + return datacube.datacube_natural_indexes(path_node) + + def find_indexes_node(self, path_node, datacube, path): + indexes = self.find_standard_indexes_node(path_node, datacube) + # path = {self.name: tuple(path_node.values)} + if not path: + if path_node: + path = {path_node.key: tuple(path_node.values)} + else: + path = {self.name: tuple()} + for transformation in self.transformations[::-1]: + indexes = transformation.find_modified_indexes(indexes, path, datacube, self) + # print(indexes) + return indexes + def offset(self, value): offset = 0 for transformation in self.transformations[::-1]: @@ -196,6 +212,7 @@ def __init__(self): # TODO: Maybe here, store transformations as a dico instead self.transformations = [] self.type = 0 + self.type_eg = 0 self.can_round = True def parse(self, value: Any) -> Any: @@ -218,6 +235,7 @@ def __init__(self): self.range = None self.transformations = [] self.type = 0.0 + self.type_eg = 0.0 self.can_round = True def parse(self, value: Any) -> Any: @@ -240,6 +258,7 @@ def __init__(self): self.range = None self.transformations = [] self.type = pd.Timestamp("2000-01-01T00:00:00") + self.type_eg = "20000101T000000" self.can_round = False def parse(self, value: Any) -> Any: @@ -270,6 +289,7 @@ def __init__(self): self.range = None self.transformations = [] self.type = np.timedelta64(0, "s") + self.type_eg = "0000" self.can_round = False def parse(self, value: Any) -> Any: @@ -300,6 +320,8 @@ def __init__(self): self.range = None self.transformations = [] self.can_round = False + self.type = "" + self.type_eg = "" def parse(self, value: Any) -> Any: return value diff --git a/polytope_feature/datacube/transformations/datacube_mappers/datacube_mappers.py b/polytope_feature/datacube/transformations/datacube_mappers/datacube_mappers.py index c5a2b551c..e06c8fd38 100644 --- a/polytope_feature/datacube/transformations/datacube_mappers/datacube_mappers.py +++ b/polytope_feature/datacube/transformations/datacube_mappers/datacube_mappers.py @@ -106,6 +106,8 @@ def unmap_path_key(self, key_value_path, leaf_path, unwanted_path, axis): if axis.name == self._mapped_axes()[1]: first_val = unwanted_path[self._mapped_axes()[0]] # unmapped_idx = [self.unmap(first_val, (val,)) for val in value] + # print("AND HERE??") + # print(values) unmapped_idx = self.unmap(first_val, values) leaf_path.pop(self._mapped_axes()[0], None) key_value_path.pop(axis.name) diff --git a/polytope_feature/datacube/transformations/datacube_mappers/mapper_types/healpix_nested.py b/polytope_feature/datacube/transformations/datacube_mappers/mapper_types/healpix_nested.py index da3f0f1ae..fe695469a 100644 --- a/polytope_feature/datacube/transformations/datacube_mappers/mapper_types/healpix_nested.py +++ b/polytope_feature/datacube/transformations/datacube_mappers/mapper_types/healpix_nested.py @@ -1,3 +1,4 @@ +import bisect import math from ..datacube_mappers import DatacubeMapper @@ -13,11 +14,9 @@ def __init__(self, base_axis, mapped_axes, resolution, md5_hash=None, local_area self._first_axis_vals = self.first_axis_vals() self.compressed_grid_axes = [self._mapped_axes[1]] self.Nside = self._resolution - self._cached_longitudes = {} self.k = int(math.log2(self.Nside)) self.Npix = 12 * self.Nside * self.Nside self.Ncap = (self.Nside * (self.Nside - 1)) << 1 - self._healpix_longitudes = {} if md5_hash is not None: self.md5_hash = md5_hash else: @@ -59,9 +58,7 @@ def second_axis_vals(self, first_val): return values def second_axis_vals_from_idx(self, first_val_idx): - if first_val_idx not in self._healpix_longitudes: - self._healpix_longitudes[first_val_idx] = self.HEALPix_longitudes(first_val_idx) - values = self._healpix_longitudes[first_val_idx] + values = self.HEALPix_longitudes(first_val_idx) return values def HEALPix_nj(self, i): @@ -77,19 +74,14 @@ def HEALPix_nj(self, i): return self.HEALPix_nj(ni - 1 - i) def HEALPix_longitudes(self, i): - if i in self._cached_longitudes: - return self._cached_longitudes[i] - else: - Nj = self.HEALPix_nj(i) - step = 360.0 / Nj - start = ( - step / 2.0 - if i < self._resolution or 3 * self._resolution - 1 < i or (i + self._resolution) % 2 - else 0.0 - ) - - longitudes = [start + n * step for n in range(Nj)] - self._cached_longitudes[i] = longitudes + Nj = self.HEALPix_nj(i) + step = 360.0 / Nj + start = ( + step / 2.0 if i < self._resolution or 3 * self._resolution - 1 < i or (i + self._resolution) % 2 else 0.0 + ) + + longitudes = [start + n * step for n in range(Nj)] + return longitudes def map_second_axis(self, first_val, lower, upper): @@ -113,31 +105,51 @@ def axes_idx_to_healpix_idx(self, first_idx, second_idx): return idx for i in range(3 * self._resolution, 4 * self._resolution - 1): if i != first_idx: - idx += 4 * (4 * self._resolution - 1 - i) + idx += 4 * (4 * self._resolution - 1 - i + 1) else: idx += second_idx return idx - def unmap(self, first_val, second_vals): + def find_second_idx(self, first_val, second_val): + tol = 1e-10 + second_axis_vals = self.second_axis_vals(first_val) + second_idx = bisect.bisect_left(second_axis_vals, second_val - tol) + return second_idx + + def unmap_first_val_to_start_line_idx(self, first_val): tol = 1e-8 - first_idx = next( - (i for i, val in enumerate(self._first_axis_vals) if first_val[0] - tol <= val <= first_val[0] + tol), None - ) - if first_idx is None: - return None - second_axis_vals = self.second_axis_vals_from_idx(first_idx) + first_val = [i for i in self._first_axis_vals if first_val - tol <= i <= first_val + tol][0] + first_idx = self._first_axis_vals.index(first_val) + idx = 0 + for i in range(self._resolution - 1): + if i != first_idx: + idx += 4 * (i + 1) + else: + return idx + for i in range(self._resolution - 1, 3 * self._resolution): + if i != first_idx: + idx += 4 * self._resolution + else: + return idx + for i in range(3 * self._resolution, 4 * self._resolution - 1): + if i != first_idx: + idx += 4 * (4 * self._resolution - 1 - i + 1) + else: + return idx - return_idxs = [] + def unmap(self, first_val, second_vals, unmapped_idx=None): + tol = 1e-8 + first_value = [i for i in self._first_axis_vals if first_val[0] - tol <= i <= first_val[0] + tol][0] + first_idx = self._first_axis_vals.index(first_value) + healpix_idxs = [] for second_val in second_vals: - second_idx = next( - (i for i, val in enumerate(second_axis_vals) if second_val - tol <= val <= second_val + tol), None - ) - if second_idx is None: - return None + second_val = [i for i in self.second_axis_vals(first_val) if second_val - tol <= i <= second_val + tol][0] + second_idx = self.second_axis_vals(first_val).index(second_val) healpix_index = self.axes_idx_to_healpix_idx(first_idx, second_idx) - nested_healpix_index = self.ring_to_nested(healpix_index) - return_idxs.append(nested_healpix_index) - return return_idxs + # TODO: here do conversion of ring to nested healpix representation before returning + healpix_index = self.ring_to_nested(healpix_index) + healpix_idxs.append(healpix_index) + return healpix_idxs def div_03(self, a, b): t = 1 if a >= (b << 1) else 0 diff --git a/polytope_feature/datacube/transformations/datacube_reverse/datacube_reverse.py b/polytope_feature/datacube/transformations/datacube_reverse/datacube_reverse.py index baa38009d..4d3232910 100644 --- a/polytope_feature/datacube/transformations/datacube_reverse/datacube_reverse.py +++ b/polytope_feature/datacube/transformations/datacube_reverse/datacube_reverse.py @@ -1,5 +1,6 @@ from ....utility.list_tools import bisect_left_cmp, bisect_right_cmp from ..datacube_transformations import DatacubeAxisTransformation +import numpy as np class DatacubeAxisReverse(DatacubeAxisTransformation): @@ -24,12 +25,17 @@ def unwanted_axes(self): def find_modified_indexes(self, indexes, path, datacube, axis): if axis.name in datacube.complete_axes: + # if isinstance(indexes, list): + # indexes.sort() + # ordered_indices = indexes + # else: ordered_indices = indexes.sort_values() else: ordered_indices = indexes return ordered_indices def find_indices_between(self, indexes, low, up, datacube, method, indexes_between_ranges, axis): + # indexes = np.asarray(indexes) indexes_between_ranges = [] if axis.name == self.name: if axis.name in datacube.complete_axes: diff --git a/polytope_feature/datacube/transformations/datacube_type_change/datacube_type_change.py b/polytope_feature/datacube/transformations/datacube_type_change/datacube_type_change.py index 7ea518c05..2a96f4103 100644 --- a/polytope_feature/datacube/transformations/datacube_type_change/datacube_type_change.py +++ b/polytope_feature/datacube/transformations/datacube_type_change/datacube_type_change.py @@ -84,14 +84,17 @@ def __init__(self, axis_name, new_type): def transform_type(self, value): try: - return pd.Timestamp(value) + return pd.Timestamp(str(value)) except ValueError: return None def make_str(self, value): values = [] for val in value: - values.append(val.strftime("%Y%m%d")) + if isinstance(val, str): + values.append(val) + else: + values.append(val.strftime("%Y%m%d")) return tuple(values) @@ -111,9 +114,12 @@ def transform_type(self, value): def make_str(self, value): values = [] for val in value: - hours = int(val.total_seconds() // 3600) - mins = int((val.total_seconds() % 3600) // 60) - values.append(f"{hours:02d}{mins:02d}") + if isinstance(val, str): + values.append(val) + else: + hours = int(val.total_seconds() // 3600) + mins = int((val.total_seconds() % 3600) // 60) + values.append(f"{hours:02d}{mins:02d}") return tuple(values) diff --git a/polytope_feature/engine/engine.py b/polytope_feature/engine/engine.py index c714db0b2..d44facc31 100644 --- a/polytope_feature/engine/engine.py +++ b/polytope_feature/engine/engine.py @@ -2,7 +2,10 @@ from ..datacube.backends.datacube import Datacube from ..datacube.tensor_index_tree import TensorIndexTree -from ..shapes import ConvexPolytope +from ..shapes import ConvexPolytope, Product + +from ..datacube.datacube_axis import UnsliceableDatacubeAxis +from ..utility.list_tools import unique class Engine: @@ -17,3 +20,46 @@ def default(): from .hullslicer import HullSlicer return HullSlicer() + + def _unique_continuous_points(self, p: ConvexPolytope, datacube: Datacube): + for i, ax in enumerate(p._axes): + mapper = datacube.get_mapper(ax) + if self.ax_is_unsliceable.get(ax, None) is None: + self.ax_is_unsliceable[ax] = isinstance(mapper, UnsliceableDatacubeAxis) + if self.ax_is_unsliceable[ax]: + break + for j, val in enumerate(p.points): + p.points[j][i] = mapper.to_float(mapper.parse(p.points[j][i])) + # Remove duplicate points + unique(p.points) + + def pre_process_polytopes(self, datacube, polytopes): + for p in polytopes: + if isinstance(p, Product): + for poly in p.polytope(): + self._unique_continuous_points(poly, datacube) + else: + self._unique_continuous_points(p, datacube) + + def find_compressed_axes(self, datacube, polytopes): + # First determine compressable axes from input polytopes + compressable_axes = [] + for polytope in polytopes: + if polytope.is_orthogonal: + for ax in polytope.axes(): + compressable_axes.append(ax) + # Cross check this list with list of compressable axis from datacube + # (should not include any merged or coupled axes) + for compressed_axis in compressable_axes: + if compressed_axis in datacube.compressed_axes: + self.compressed_axes.append(compressed_axis) + # add the last axis of the grid always (longitude) as a compressed axis + k, last_value = _, datacube.axes[k] = datacube.axes.popitem() + self.compressed_axes.append(k) + + def remove_compressed_axis_in_union(self, polytopes): + for p in polytopes: + if p.is_in_union: + for axis in p.axes(): + if axis == self.compressed_axes[-1]: + self.compressed_axes.remove(axis) diff --git a/polytope_feature/engine/hullslicer.py b/polytope_feature/engine/hullslicer.py index efbdd21df..64b50b640 100644 --- a/polytope_feature/engine/hullslicer.py +++ b/polytope_feature/engine/hullslicer.py @@ -6,13 +6,12 @@ import scipy.spatial from ..datacube.backends.datacube import Datacube -from ..datacube.datacube_axis import UnsliceableDatacubeAxis from ..datacube.tensor_index_tree import TensorIndexTree from ..shapes import ConvexPolytope, Product from ..utility.combinatorics import group, tensor_product from ..utility.exceptions import UnsliceableShapeError from ..utility.geometry import lerp -from ..utility.list_tools import argmax, argmin, unique +from ..utility.list_tools import argmax, argmin from .engine import Engine @@ -25,18 +24,6 @@ def __init__(self): self.remapped_vals = {} self.compressed_axes = [] - def _unique_continuous_points(self, p: ConvexPolytope, datacube: Datacube): - for i, ax in enumerate(p._axes): - mapper = datacube.get_mapper(ax) - if self.ax_is_unsliceable.get(ax, None) is None: - self.ax_is_unsliceable[ax] = isinstance(mapper, UnsliceableDatacubeAxis) - if self.ax_is_unsliceable[ax]: - break - for j, val in enumerate(p.points): - p.points[j][i] = mapper.to_float(mapper.parse(p.points[j][i])) - # Remove duplicate points - unique(p.points) - def _build_unsliceable_child(self, polytope, ax, node, datacube, lowers, next_nodes, slice_axis_idx): if not polytope.is_flat: raise UnsliceableShapeError(ax) @@ -180,29 +167,6 @@ def _build_branch(self, ax, node, datacube, next_nodes): del node["unsliced_polytopes"] - def find_compressed_axes(self, datacube, polytopes): - # First determine compressable axes from input polytopes - compressable_axes = [] - for polytope in polytopes: - if polytope.is_orthogonal: - for ax in polytope.axes(): - compressable_axes.append(ax) - # Cross check this list with list of compressable axis from datacube - # (should not include any merged or coupled axes) - for compressed_axis in compressable_axes: - if compressed_axis in datacube.compressed_axes: - self.compressed_axes.append(compressed_axis) - # add the last axis of the grid always (longitude) as a compressed axis - k, last_value = _, datacube.axes[k] = datacube.axes.popitem() - self.compressed_axes.append(k) - - def remove_compressed_axis_in_union(self, polytopes): - for p in polytopes: - if p.is_in_union: - for axis in p.axes(): - if axis == self.compressed_axes[-1]: - self.compressed_axes.remove(axis) - def extract(self, datacube: Datacube, polytopes: List[ConvexPolytope]): # Determine list of axes to compress self.find_compressed_axes(datacube, polytopes) @@ -211,12 +175,7 @@ def extract(self, datacube: Datacube, polytopes: List[ConvexPolytope]): self.remove_compressed_axis_in_union(polytopes) # Convert the polytope points to float type to support triangulation and interpolation - for p in polytopes: - if isinstance(p, Product): - for poly in p.polytope(): - self._unique_continuous_points(poly, datacube) - else: - self._unique_continuous_points(p, datacube) + self.pre_process_polytopes(datacube, polytopes) groups, input_axes = group(polytopes) datacube.validate(input_axes) @@ -243,7 +202,6 @@ def extract(self, datacube: Datacube, polytopes: List[ConvexPolytope]): final_polys.extend(poly.polytope()) else: final_polys.append(poly) - # r["unsliced_polytopes"] = set(new_c) r["unsliced_polytopes"] = set(final_polys) current_nodes = [r] for ax in datacube.axes.values(): diff --git a/polytope_feature/engine/qubed_slicer.py b/polytope_feature/engine/qubed_slicer.py new file mode 100644 index 000000000..57df7dfbc --- /dev/null +++ b/polytope_feature/engine/qubed_slicer.py @@ -0,0 +1,338 @@ + +from qubed import Qube +from qubed.value_types import QEnum +from qubed.set_operations import union +from .hullslicer import slice +from ..datacube.backends.qubed import QubedDatacube +from .engine import Engine +import pandas as pd +from ..datacube.datacube_axis import UnsliceableDatacubeAxis +from ..datacube.transformations.datacube_mappers.datacube_mappers import DatacubeMapper +from ..shapes import ConvexPolytope, Product +from ..utility.combinatorics import group, tensor_product +from typing import List + +from ..datacube.backends.datacube import Datacube +import math + + +class QubedSlicer(Engine): + def __init__(self): + self.ax_is_unsliceable = {} + self.compressed_axes = [] + self.remapped_vals = {} + + def find_datacube_vals(): + # TODO + pass + + def find_values_between(self, polytope, ax, node, datacube, lower, upper, path=None): + if isinstance(ax, UnsliceableDatacubeAxis): + return [v for v in node.values if lower <= v <= upper] + + tol = ax.tol + lower = ax.from_float(lower - tol) + upper = ax.from_float(upper + tol) + + method = polytope.method + + # values = datacube.get_indices(flattened, ax, lower, upper, method) + values = datacube.get_indices(path, node, ax, lower, upper, method) + return values + + def remap_values(self, ax, value): + remapped_val = self.remapped_vals.get((value, ax.name), None) + if remapped_val is None: + remapped_val = value + if ax.is_cyclic: + remapped_val_interm = ax.remap([value, value])[0] + remapped_val = (remapped_val_interm[0] + remapped_val_interm[1]) / 2 + if ax.can_round: + remapped_val = round(remapped_val, int(-math.log10(ax.tol))) + self.remapped_vals[(value, ax.name)] = remapped_val + return remapped_val + + def _actual_slice(self, q: Qube, polytopes_to_slice, datacube, datacube_transformations) -> 'Qube': + + def find_polytopes_on_axis(axis_name, polytopes): + polytopes_on_axis = [] + for poly in polytopes: + if axis_name in poly._axes: + polytopes_on_axis.append(poly) + return polytopes_on_axis + + def change_datacube_val_types(child: Qube, datacube_transformations): + axis_name = child.key + transformation = datacube_transformations.get(axis_name, None) + child_vals = child.values + + # TODO: use axis.find_indexes_between to find the right child_vals + # TODO: actually, build same as find_values_between(self, polytope, ax, node, datacube, lower, upper) by writing new functions in qubed backend + new_vals = [] + for val in child_vals: + if transformation: + new_vals.append(transformation.transform_type(val)) + else: + new_vals.append(val) + + return new_vals + + def transform_upper_lower(axis_name, lower, upper, datacube): + ax = datacube._axes[axis_name] + if isinstance(ax, UnsliceableDatacubeAxis): + return (lower, upper) + tol = ax.tol + lower = ax.from_float(lower - tol) + upper = ax.from_float(upper + tol) + + return (lower, upper) + + def _slice_second_grid_axis(axis_name, polytopes, datacube, datacube_transformations, second_axis_vals, path) -> list[Qube]: + result = [] + polytopes_on_axis = find_polytopes_on_axis(axis_name, polytopes) + + for poly in polytopes_on_axis: + ax = datacube._axes[axis_name] + lower, upper, slice_axis_idx = poly.extents(axis_name) + + # new_lower, new_upper = transform_upper_lower(axis_name, lower, upper, datacube) + # found_vals = [v for v in second_axis_vals if new_lower <= v <= new_upper] + found_vals = self.find_values_between(poly, ax, None, datacube, lower, upper, path) + + if len(found_vals) == 0: + continue + + # slice polytope along each value on child and keep resulting polytopes in memory + sliced_polys = [] + for val in found_vals: + # ax = datacube._axes[axis_name] + if not isinstance(ax, UnsliceableDatacubeAxis): + fval = ax.to_float(val) + # slice polytope along the value and add sliced polytope to list of polytopes in memory + sliced_poly = slice(poly, axis_name, fval, slice_axis_idx) + sliced_polys.append(sliced_poly) + # decide if axis should be compressed or not according to polytope + # NOTE: actually the second grid axis will always be compressed + + # if it's not compressed, need to separate into different nodes to append to the tree + + new_found_vals = [] + for found_val in found_vals: + found_val = self.remap_values(ax, found_val) + if isinstance(found_val, pd.Timedelta) or isinstance(found_val, pd.Timestamp): + new_found_vals.append(str(found_val)) + else: + new_found_vals.append(found_val) + + # NOTE this was the last axis so we do not have children... + + result.extend([Qube.make( + key=axis_name, + values=QEnum(new_found_vals), + metadata={}, + children={} + )]) + return result + + def _slice(q: Qube, polytopes, datacube, datacube_transformations) -> list[Qube]: + result = [] + + if len(q.children) == 0: + # add "fake" axes and their nodes in order -> what about merged axes?? + mapper_transformation = None + for transformation in list(datacube_transformations.values()): + if isinstance(transformation, DatacubeMapper): + mapper_transformation = transformation + if not mapper_transformation: + # There is no grid mapping + pass + else: + # Slice on the two grid axes + grid_axes = mapper_transformation._mapped_axes + + # Handle first grid axis + polytopes_on_axis = find_polytopes_on_axis(grid_axes[0], polytopes) + + for poly in polytopes_on_axis: + ax = datacube._axes[grid_axes[0]] + lower, upper, slice_axis_idx = poly.extents(grid_axes[0]) + + first_ax_vals = mapper_transformation.first_axis_vals() + + # new_lower, new_upper = transform_upper_lower(grid_axes[0], lower, upper, datacube) + # found_vals = [v for v in first_ax_vals if new_lower <= v <= new_upper] + found_vals = self.find_values_between(poly, ax, None, datacube, lower, upper) + + if len(found_vals) == 0: + continue + + # slice polytope along each value on child and keep resulting polytopes in memory + sliced_polys = [] + for val in found_vals: + # ax = datacube._axes[grid_axes[0]] + if not isinstance(ax, UnsliceableDatacubeAxis): + fval = ax.to_float(val) + # slice polytope along the value and add sliced polytope to list of polytopes in memory + sliced_poly = slice(poly, grid_axes[0], fval, slice_axis_idx) + sliced_polys.append(sliced_poly) + # decide if axis should be compressed or not according to polytope + # NOTE: actually the first grid axis will never be compressed + axis_compressed = (grid_axes[0] in self.compressed_axes) + + # if it's not compressed, need to separate into different nodes to append to the tree + for i, found_val in enumerate(found_vals): + found_val = self.remap_values(ax, found_val) + child_polytopes = [p for p in polytopes if p != poly] + if sliced_polys[i]: + child_polytopes.append(sliced_polys[i]) + + second_axis_vals = mapper_transformation.second_axis_vals([found_val]) + flattened_path = {grid_axes[0]: (found_val,)} + # get second axis children through slicing + children = _slice_second_grid_axis( + grid_axes[1], child_polytopes, datacube, datacube_transformations, second_axis_vals, flattened_path) + # If this node used to have children but now has none due to filtering, skip it. + if not children: + continue + if isinstance(found_val, pd.Timedelta) or isinstance(found_val, pd.Timestamp): + found_val = [str(found_val)] + + # TODO: remap the found_val using self.remap_values like in the hullslicer + + # TODO: when we have an axis that we would like to merge with another, we should skip the node creation here + # and instead keep/cache the value to merge with the node from before?? + + qube_node = Qube.make(key=grid_axes[0], + values=QEnum([found_val]), + metadata={}, + children=children) + result.append(qube_node) + + for i, child in enumerate(q.children): + # find polytopes which are defined on axis child.key + polytopes_on_axis = find_polytopes_on_axis(child.key, polytopes) + + # here now first change the values in the polytopes on the axis to reflect the axis type + + for poly in polytopes_on_axis: + ax = datacube._axes[child.key] + # find extents of polytope on child.key + lower, upper, slice_axis_idx = poly.extents(child.key) + + # # find values on child that are within extents + # # here first change the child values of the datacube ie the Qubed tree to their right type with the transformation + # modified_vals = change_datacube_val_types(child, datacube_transformations) + + # # here use the axis to transform lower and upper to right type too + # new_lower, new_upper = transform_upper_lower(child.key, lower, upper, datacube) + # found_vals = [v for v in modified_vals if new_lower <= v <= new_upper] + found_vals = self.find_values_between(poly, ax, child, datacube, lower, upper) + + if len(found_vals) == 0: + continue + + # slice polytope along each value on child and keep resulting polytopes in memory + sliced_polys = [] + for val in found_vals: + # ax = datacube._axes[child.key] + if not isinstance(ax, UnsliceableDatacubeAxis): + fval = ax.to_float(val) + # slice polytope along the value and add sliced polytope to list of polytopes in memory + sliced_poly = slice(poly, child.key, fval, slice_axis_idx) + sliced_polys.append(sliced_poly) + # decide if axis should be compressed or not according to polytope + axis_compressed = (child.key in self.compressed_axes) + # if it's not compressed, need to separate into different nodes to append to the tree + if not axis_compressed and len(found_vals) > 1: + for i, found_val in enumerate(found_vals): + found_val = self.remap_values(ax, found_val) + child_polytopes = [p for p in polytopes if p != poly] + if sliced_polys[i]: + child_polytopes.append(sliced_polys[i]) + children = _slice(child, child_polytopes, datacube, datacube_transformations) + # If this node used to have children but now has none due to filtering, skip it. + if child.children and not children: + continue + if isinstance(found_val, pd.Timedelta) or isinstance(found_val, pd.Timestamp): + found_val = [str(found_val)] + + # TODO: when we have an axis that we would like to merge with another, we should skip the node creation here + # and instead keep/cache the value to merge with the node from before?? + + qube_node = Qube.make(key=child.key, + values=QEnum(found_val), + metadata=child.metadata, + children=children) + result.append(qube_node) + else: + # if it's compressed, then can add all found values in a single node + child_polytopes = [p for p in polytopes if p != poly] + child_polytopes.extend( + [sliced_poly_ for sliced_poly_ in sliced_polys if sliced_poly_ is not None]) + # create children + children = _slice(child, child_polytopes, datacube, datacube_transformations) + # If this node used to have children but now has none due to filtering, skip it. + if child.children and not children: + continue + + new_found_vals = [] + for found_val in found_vals: + found_val = self.remap_values(ax, found_val) + if isinstance(found_val, pd.Timedelta) or isinstance(found_val, pd.Timestamp): + new_found_vals.append(str(found_val)) + else: + new_found_vals.append(found_val) + + result.extend([Qube.make( + key=child.key, + values=QEnum(new_found_vals), + metadata=child.metadata, + children=children + )]) + + return result + + return Qube.root_node(_slice(q, polytopes_to_slice, datacube, datacube_transformations)) + + def actual_slice(self, q: Qube, polytopes_to_slice, datacube, datacube_transformations): + + groups, input_axes = group(polytopes_to_slice) + combinations = tensor_product(groups) + + sub_trees = [] + + # NOTE: could optimise here if we know combinations will always be for one request. + # Then we do not need to create a new index tree and merge it to request, but can just + # directly work on request and return it... + + for c in combinations: + new_c = [] + for combi in c: + if isinstance(combi, list): + new_c.extend(combi) + else: + new_c.append(combi) + final_polys = [] + for poly in new_c: + if isinstance(poly, Product): + final_polys.extend(poly.polytope()) + else: + final_polys.append(poly) + + # Get the sliced Qube for each combi + r = self._actual_slice(q, final_polys, datacube, datacube_transformations) + sub_trees.append(r) + + final_tree = sub_trees[0] + + for sub_tree in sub_trees[1:]: + union(final_tree, sub_tree) + return final_tree + + def extract(self, datacube: Datacube, polytopes: List[ConvexPolytope]): + self.find_compressed_axes(datacube, polytopes) + self.pre_process_polytopes(datacube, polytopes) + assert isinstance(datacube, QubedDatacube) + tree = self.actual_slice(datacube.q, polytopes, datacube, + datacube.datacube_transformations) + return tree diff --git a/tests/test_ecmwf_oper_data_fdb.py b/tests/test_ecmwf_oper_data_fdb.py index e848716c8..2efdfe4a8 100644 --- a/tests/test_ecmwf_oper_data_fdb.py +++ b/tests/test_ecmwf_oper_data_fdb.py @@ -56,7 +56,8 @@ def test_fdb_datacube(self): Select("class", ["od"]), Select("stream", ["oper"]), Select("type", ["fc"]), - Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]), + # Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]), + Box(["latitude", "longitude"], [0, 0], [80, 80]), ) self.fdbdatacube = gj.GribJump() self.slicer = HullSlicer() @@ -67,8 +68,8 @@ def test_fdb_datacube(self): ) result = self.API.retrieve(request) result.pprint() - assert len(result.leaves) == 3 - assert len(result.leaves[0].result) == 3 + # assert len(result.leaves) == 3 + # assert len(result.leaves[0].result) == 3 @pytest.mark.fdb def test_fdb_datacube_point(self): diff --git a/tests/test_qubed_extraction.py b/tests/test_qubed_extraction.py new file mode 100644 index 000000000..ef8e19d17 --- /dev/null +++ b/tests/test_qubed_extraction.py @@ -0,0 +1,114 @@ +# from qubed import Qube +# import requests +# from polytope_feature.datacube.datacube_axis import PandasTimedeltaDatacubeAxis, PandasTimestampDatacubeAxis, UnsliceableDatacubeAxis, FloatDatacubeAxis +# from polytope_feature.datacube.backends.test_qubed_slicing import actual_slice +# from polytope_feature.datacube.transformations.datacube_type_change.datacube_type_change import TypeChangeStrToTimestamp, TypeChangeStrToTimedelta +# import pandas as pd +# from polytope_feature.datacube.transformations.datacube_mappers.mapper_types.healpix_nested import NestedHealpixGridMapper + +# from polytope_feature.shapes import ConvexPolytope + + +# fdb_tree = Qube.from_json(requests.get( +# "https://github.com/ecmwf/qubed/raw/refs/heads/main/tests/example_qubes/climate_dt.json").json()) + + +# # fdb_tree = fdb_tree.remove_by_key(["year"]).remove_by_key(["month"]) + +# fdb_tree.print() + +# print(fdb_tree.axes().keys()) + + +# # combi_polytopes = [ +# # ConvexPolytope(["param"], [["168"]]), +# # ConvexPolytope(["time"], [[pd.Timedelta(hours=0, minutes=0)], [pd.Timedelta(hours=12, minutes=0)]]), +# # ConvexPolytope(["resolution"], [["high"]]), +# # ConvexPolytope(["type"], [["fc"]]), +# # ConvexPolytope(["model"], [['ifs-nemo']]), +# # ConvexPolytope(["stream"], [["clte"]]), +# # ConvexPolytope(["realization"], ["1"]), +# # ConvexPolytope(["expver"], [['0001']]), +# # ConvexPolytope(["experiment"], [['ssp3-7.0']]), +# # ConvexPolytope(["generation"], [["1"]]), +# # ConvexPolytope(["levtype"], [["sfc"]]), +# # ConvexPolytope(["activity"], [["scenariomip"]]), +# # ConvexPolytope(["dataset"], [["climate-dt"]]), +# # ConvexPolytope(["class"], [["d1"]]), +# # ConvexPolytope(["date"], [[pd.Timestamp("20210728")], [pd.Timestamp("20210729")]]) +# # ] + +# # TODO: add lat/lon polygon +# combi_polytopes = [ +# ConvexPolytope(["param"], [["164"]]), +# ConvexPolytope(["time"], [[pd.Timedelta(hours=0, minutes=0)], [pd.Timedelta(hours=12, minutes=0)]]), +# ConvexPolytope(["resolution"], [["high"]]), +# ConvexPolytope(["type"], [["fc"]]), +# ConvexPolytope(["model"], [['ifs-nemo']]), +# ConvexPolytope(["stream"], [["clte"]]), +# ConvexPolytope(["realization"], ["1"]), +# ConvexPolytope(["expver"], [['0001']]), +# ConvexPolytope(["experiment"], [['ssp3-7.0']]), +# ConvexPolytope(["generation"], [["1"]]), +# ConvexPolytope(["levtype"], [["sfc"]]), +# ConvexPolytope(["activity"], [["scenariomip"]]), +# ConvexPolytope(["dataset"], [["climate-dt"]]), +# ConvexPolytope(["class"], [["d1"]]), +# ConvexPolytope(["date"], [[pd.Timestamp("20220811")], [pd.Timestamp("20220912")]]), +# ConvexPolytope(["latitude", "longitude"], [[0, 0], [0.5, 0.5], [0, 0.5]]) +# ] + +# # TODO: add lat and lon axes +# datacube_axes = {"param": UnsliceableDatacubeAxis(), +# "time": PandasTimedeltaDatacubeAxis(), +# "resolution": UnsliceableDatacubeAxis(), +# "type": UnsliceableDatacubeAxis(), +# "model": UnsliceableDatacubeAxis(), +# "stream": UnsliceableDatacubeAxis(), +# "realization": UnsliceableDatacubeAxis(), +# "expver": UnsliceableDatacubeAxis(), +# "experiment": UnsliceableDatacubeAxis(), +# "generation": UnsliceableDatacubeAxis(), +# "levtype": UnsliceableDatacubeAxis(), +# "activity": UnsliceableDatacubeAxis(), +# "dataset": UnsliceableDatacubeAxis(), +# "class": UnsliceableDatacubeAxis(), +# "date": PandasTimestampDatacubeAxis(), +# "latitude": FloatDatacubeAxis(), +# "longitude": FloatDatacubeAxis()} + +# time_val = pd.Timedelta(hours=0, minutes=0) +# date_val = pd.Timestamp("20300101T000000") + + +# # TODO: add grid axis transformation +# datacube_transformations = { +# "time": TypeChangeStrToTimedelta("time", time_val), +# "date": TypeChangeStrToTimestamp("date", date_val), +# "values": NestedHealpixGridMapper("values", ["latitude", "longitude"], 1024) +# } + + +# sliced_tree = actual_slice(fdb_tree, combi_polytopes, datacube_axes, datacube_transformations) + + +# print("THE FINAL RESULT IS") +# print(sliced_tree) + +# # TODO: treat the transformations to talk to the qubed tree, maybe do it + +# # TODO: start iterating fdb_tree and creating a new request tree + +# # print(fdb_tree.) + + +# # Select("step", [0]), +# # Select("levtype", ["sfc"]), +# # Select("date", [pd.Timestamp("20231102T000000")]), +# # Select("domain", ["g"]), +# # Select("expver", ["0001"]), +# # Select("param", ["167"]), +# # Select("class", ["od"]), +# # Select("stream", ["oper"]), +# # Select("type", ["fc"]), +# # Box(["latitude", "longitude"], [0, 0], [80, 80]), diff --git a/tests/test_qubed_extraction_engine.py b/tests/test_qubed_extraction_engine.py new file mode 100644 index 000000000..91971ad78 --- /dev/null +++ b/tests/test_qubed_extraction_engine.py @@ -0,0 +1,253 @@ +from polytope_feature.shapes import Box, Select, Span +from polytope_feature.polytope import Polytope, Request +from polytope_feature.engine.qubed_slicer import QubedSlicer +from polytope_feature.datacube.backends.qubed import QubedDatacube +from polytope_feature.datacube.backends.fdb import FDBDatacube +import pytest +from qubed import Qube +import requests +from polytope_feature.datacube.datacube_axis import PandasTimedeltaDatacubeAxis, PandasTimestampDatacubeAxis, UnsliceableDatacubeAxis, FloatDatacubeAxis +# from polytope_feature.datacube.backends.test_qubed_slicing import actual_slice +from polytope_feature.datacube.transformations.datacube_type_change.datacube_type_change import TypeChangeStrToTimestamp, TypeChangeStrToTimedelta +import pandas as pd +from polytope_feature.datacube.transformations.datacube_mappers.mapper_types.healpix_nested import NestedHealpixGridMapper + +from polytope_feature.shapes import ConvexPolytope +import time +import pygribjump as gj +from polytope_feature.engine.hullslicer import HullSlicer + + +def find_relevant_subcube_from_request(request, qube_url): + + # NOTE: final url we want is like: + # "https://qubed.lumi.apps.dte.destination-earth.eu/api/v1/select/climate-dt/?class=d1&dataset=climate-dt" + + for shape in request.shapes: + if isinstance(shape, Select): + qube_url += shape.axis + "=" + for i, val in enumerate(shape.values): + qube_url += str(val) + if i < len(shape.values) - 1: + qube_url += "," + qube_url += "&" + # TODO: remove last unnecessary & + qube_url = qube_url[:-1] + return qube_url + + +fdb_tree = Qube.from_json(requests.get( + "https://github.com/ecmwf/qubed/raw/refs/heads/main/tests/example_qubes/climate_dt.json").json()) + + +combi_polytopes = [ + ConvexPolytope(["param"], [["164"]]), + ConvexPolytope(["time"], [[pd.Timedelta(hours=0, minutes=0)], [pd.Timedelta(hours=12, minutes=0)]]), + ConvexPolytope(["resolution"], [["high"]]), + ConvexPolytope(["type"], [["fc"]]), + ConvexPolytope(["model"], [['ifs-nemo']]), + ConvexPolytope(["stream"], [["clte"]]), + ConvexPolytope(["realization"], ["1"]), + ConvexPolytope(["expver"], [['0001']]), + ConvexPolytope(["experiment"], [['ssp3-7.0']]), + ConvexPolytope(["generation"], [["1"]]), + ConvexPolytope(["levtype"], [["sfc"]]), + ConvexPolytope(["activity"], [["scenariomip"]]), + ConvexPolytope(["dataset"], [["climate-dt"]]), + ConvexPolytope(["class"], [["d1"]]), + ConvexPolytope(["date"], [[pd.Timestamp("20220811")], [pd.Timestamp("20220912")]]), + ConvexPolytope(["latitude", "longitude"], [[0, 0], [0.5, 0.5], [0, 0.5]]) +] + +# TODO: add lat and lon axes +datacube_axes = {"param": UnsliceableDatacubeAxis(), + "time": PandasTimedeltaDatacubeAxis(), + "resolution": UnsliceableDatacubeAxis(), + "type": UnsliceableDatacubeAxis(), + "model": UnsliceableDatacubeAxis(), + "stream": UnsliceableDatacubeAxis(), + "realization": UnsliceableDatacubeAxis(), + "expver": UnsliceableDatacubeAxis(), + "experiment": UnsliceableDatacubeAxis(), + "generation": UnsliceableDatacubeAxis(), + "levtype": UnsliceableDatacubeAxis(), + "activity": UnsliceableDatacubeAxis(), + "dataset": UnsliceableDatacubeAxis(), + "class": UnsliceableDatacubeAxis(), + "date": PandasTimestampDatacubeAxis(), + # "latitude": FloatDatacubeAxis(), + # "longitude": FloatDatacubeAxis() + } + +time_val = pd.Timedelta(hours=0, minutes=0) +date_val = pd.Timestamp("20300101T000000") + + +# TODO: add grid axis transformation +datacube_transformations = { + "time": TypeChangeStrToTimedelta("time", time_val), + "date": TypeChangeStrToTimestamp("date", date_val), + "values": NestedHealpixGridMapper("values", ["latitude", "longitude"], 1024) +} + + +options = { + "axis_config": [ + {"axis_name": "step", "transformations": [{"name": "type_change", "type": "int"}]}, + {"axis_name": "number", "transformations": [{"name": "type_change", "type": "int"}]}, + # { + # "axis_name": "date", + # "transformations": [{"name": "merge", "other_axis": "time", "linkers": ["T", "00"]}], + # }, + {"axis_name": "date", "transformations": [{"name": "type_change", "type": "date"}]}, + {"axis_name": "time", "transformations": [{"name": "type_change", "type": "time"}]}, + { + "axis_name": "values", + "transformations": [ + {"name": "mapper", "type": "healpix_nested", "resolution": 1024, "axes": ["latitude", "longitude"]} + ], + }, + {"axis_name": "latitude", "transformations": [{"name": "reverse", "is_reverse": True}]}, + {"axis_name": "longitude", "transformations": [{"name": "cyclic", "range": [0, 360]}]}, + ], + "compressed_axes_config": [ + "longitude", + "latitude", + "levtype", + "step", + "date", + "domain", + "expver", + "param", + "class", + "stream", + "type", + ], + "pre_path": {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper"}, +} + +# request = Request( +# Select("step", [0]), +# Select("levtype", ["sfc"]), +# Select("date", [pd.Timestamp("20230625T120000")]), +# Select("domain", ["g"]), +# Select("expver", ["0001"]), +# Select("param", ["167"]), +# Select("class", ["od"]), +# Select("stream", ["oper"]), +# Select("type", ["an"]), +# Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]), +# ) + +request = Request(ConvexPolytope(["param"], [["164"]]), + ConvexPolytope(["time"], [[pd.Timedelta(hours=1, minutes=0)], [pd.Timedelta(hours=3, minutes=0)]]), + ConvexPolytope(["resolution"], [["high"]]), + ConvexPolytope(["type"], [["fc"]]), + ConvexPolytope(["model"], [['ifs-nemo']]), + ConvexPolytope(["stream"], [["clte"]]), + ConvexPolytope(["realization"], ["1"]), + ConvexPolytope(["expver"], [['0001']]), + ConvexPolytope(["experiment"], [['ssp3-7.0']]), + ConvexPolytope(["generation"], [["1"]]), + ConvexPolytope(["levtype"], [["sfc"]]), + ConvexPolytope(["activity"], [["scenariomip"]]), + ConvexPolytope(["dataset"], [["climate-dt"]]), + ConvexPolytope(["class"], [["d1"]]), + ConvexPolytope(["date"], [[pd.Timestamp("20220811")]]), + ConvexPolytope(["latitude", "longitude"], [[0, 0], [0.5, 0.5], [0, 0.5]]) + # ConvexPolytope(["latitude", "longitude"], [[0, 0], [-0.5, -0.5], [0, -0.5]]) + ) + +qubeddatacube = QubedDatacube(fdb_tree, datacube_axes, datacube_transformations) +slicer = QubedSlicer() +self_API = Polytope( + datacube=qubeddatacube, + engine=slicer, + options=options, +) +time1 = time.time() +result = self_API.retrieve(request) +time2 = time.time() + +print(result) + +print("TIME EXTRACTING USING QUBED") +print(time2 - time1) + +# USING NORMAL GJ + + +options = { + "axis_config": [ + {"axis_name": "step", "transformations": [{"name": "type_change", "type": "int"}]}, + {"axis_name": "number", "transformations": [{"name": "type_change", "type": "int"}]}, + # { + # "axis_name": "date", + # "transformations": [{"name": "merge", "other_axis": "time", "linkers": ["T", "00"]}], + # }, + {"axis_name": "date", "transformations": [{"name": "type_change", "type": "date"}]}, + {"axis_name": "time", "transformations": [{"name": "type_change", "type": "time"}]}, + { + "axis_name": "values", + "transformations": [ + {"name": "mapper", "type": "healpix_nested", "resolution": 1024, "axes": ["latitude", "longitude"]} + ], + }, + {"axis_name": "latitude", "transformations": [{"name": "reverse", "is_reverse": True}]}, + {"axis_name": "longitude", "transformations": [{"name": "cyclic", "range": [0, 360]}]}, + ], + "compressed_axes_config": [ + "longitude", + # "latitude", + # "levtype", + # "step", + # "date", + # "domain", + # "expver", + # "param", + # "class", + # "stream", + # "type", + ], + "pre_path": {"class": "d1", "model": "ifs-nemo", "resolution": "high"}, +} + +fdbdatacube = gj.GribJump() +slicer = HullSlicer() +self_API = Polytope( + datacube=fdbdatacube, + engine=slicer, + options=options, +) + + +request = Request(ConvexPolytope(["param"], [["164"]]), + ConvexPolytope(["time"], [[pd.Timedelta(hours=1, minutes=0)], [pd.Timedelta(hours=3, minutes=0)]]), + ConvexPolytope(["resolution"], [["high"]]), + ConvexPolytope(["type"], [["fc"]]), + ConvexPolytope(["model"], [['ifs-nemo']]), + ConvexPolytope(["stream"], [["clte"]]), + ConvexPolytope(["realization"], ["1"]), + ConvexPolytope(["expver"], [['0001']]), + ConvexPolytope(["experiment"], [['ssp3-7.0']]), + ConvexPolytope(["generation"], [["1"]]), + ConvexPolytope(["levtype"], [["sfc"]]), + ConvexPolytope(["activity"], [["scenariomip"]]), + ConvexPolytope(["dataset"], [["climate-dt"]]), + ConvexPolytope(["class"], [["d1"]]), + ConvexPolytope(["date"], [[pd.Timestamp("20220811")]]), + ConvexPolytope(["latitude", "longitude"], [[0, 0], [0.5, 0.5], [0, 0.5]])) + +time3 = time.time() +result = self_API.retrieve(request) +time4 = time.time() + +print("TIME EXTRACTING USING GJ NORMAL") +print(time4 - time3) + + +# print(result) + +# print(result.leaves) + +# sliced_tree = actual_slice(fdb_tree, combi_polytopes, datacube_axes, datacube_transformations) diff --git a/tests/test_qubed_extraction_service.py b/tests/test_qubed_extraction_service.py new file mode 100644 index 000000000..ddf235a96 --- /dev/null +++ b/tests/test_qubed_extraction_service.py @@ -0,0 +1,269 @@ +from polytope_feature.shapes import Box, Select, Span +from polytope_feature.polytope import Polytope, Request +from polytope_feature.engine.qubed_slicer import QubedSlicer +from polytope_feature.datacube.backends.qubed import QubedDatacube +from polytope_feature.datacube.backends.fdb import FDBDatacube +import pytest +from qubed import Qube +import requests +from polytope_feature.datacube.datacube_axis import PandasTimedeltaDatacubeAxis, PandasTimestampDatacubeAxis, UnsliceableDatacubeAxis, FloatDatacubeAxis +# from polytope_feature.datacube.backends.test_qubed_slicing import actual_slice +from polytope_feature.datacube.transformations.datacube_type_change.datacube_type_change import TypeChangeStrToTimestamp, TypeChangeStrToTimedelta +import pandas as pd +from polytope_feature.datacube.transformations.datacube_mappers.mapper_types.healpix_nested import NestedHealpixGridMapper + +from polytope_feature.shapes import ConvexPolytope +import time +import pygribjump as gj +from polytope_feature.engine.hullslicer import HullSlicer + + +def find_relevant_subcube_from_request(request, qube_url): + + # NOTE: final url we want is like: + # "https://qubed.lumi.apps.dte.destination-earth.eu/api/v1/select/climate-dt/?class=d1&dataset=climate-dt" + + for shape in request.shapes: + if isinstance(shape, Select): + qube_url += shape.axis + "=" + for i, val in enumerate(shape.values): + qube_url += str(val) + if i < len(shape.values) - 1: + qube_url += "," + qube_url += "&" + # TODO: remove last unnecessary & + qube_url = qube_url[:-1] + return qube_url + + +def get_fdb_tree(request): + qube_url_start = "https://qubed.lumi.apps.dte.destination-earth.eu/api/v1/select/climate-dt/?" + qube_url = find_relevant_subcube_from_request(request, qube_url_start) + fdb_tree = Qube.from_json(requests.get(qube_url).json()) + return fdb_tree + + +fdb_tree = Qube.from_json(requests.get( + "https://github.com/ecmwf/qubed/raw/refs/heads/main/tests/example_qubes/climate_dt.json").json()) + + +# print(fdb_tree) + +# combi_polytopes = [ +# ConvexPolytope(["param"], [["164"]]), +# ConvexPolytope(["time"], [[pd.Timedelta(hours=0, minutes=0)], [pd.Timedelta(hours=12, minutes=0)]]), +# ConvexPolytope(["resolution"], [["high"]]), +# ConvexPolytope(["type"], [["fc"]]), +# ConvexPolytope(["model"], [['ifs-nemo']]), +# ConvexPolytope(["stream"], [["clte"]]), +# ConvexPolytope(["realization"], ["1"]), +# ConvexPolytope(["expver"], [['0001']]), +# ConvexPolytope(["experiment"], [['ssp3-7.0']]), +# ConvexPolytope(["generation"], [["1"]]), +# ConvexPolytope(["levtype"], [["sfc"]]), +# ConvexPolytope(["activity"], [["scenariomip"]]), +# ConvexPolytope(["dataset"], [["climate-dt"]]), +# ConvexPolytope(["class"], [["d1"]]), +# ConvexPolytope(["date"], [[pd.Timestamp("20220811")], [pd.Timestamp("20220912")]]), +# ConvexPolytope(["latitude", "longitude"], [[0, 0], [0.5, 0.5], [0, 0.5]]) +# ] + +# TODO: add lat and lon axes +datacube_axes = {"param": UnsliceableDatacubeAxis(), + "time": PandasTimedeltaDatacubeAxis(), + "resolution": UnsliceableDatacubeAxis(), + "type": UnsliceableDatacubeAxis(), + "model": UnsliceableDatacubeAxis(), + "stream": UnsliceableDatacubeAxis(), + "realization": UnsliceableDatacubeAxis(), + "expver": UnsliceableDatacubeAxis(), + "experiment": UnsliceableDatacubeAxis(), + "generation": UnsliceableDatacubeAxis(), + "levtype": UnsliceableDatacubeAxis(), + "activity": UnsliceableDatacubeAxis(), + "dataset": UnsliceableDatacubeAxis(), + "class": UnsliceableDatacubeAxis(), + "date": PandasTimestampDatacubeAxis(), + "latitude": FloatDatacubeAxis(), + "longitude": FloatDatacubeAxis()} + +time_val = pd.Timedelta(hours=0, minutes=0) +date_val = pd.Timestamp("20300101T000000") + + +# TODO: add grid axis transformation +datacube_transformations = { + "time": TypeChangeStrToTimedelta("time", time_val), + "date": TypeChangeStrToTimestamp("date", date_val), + "values": NestedHealpixGridMapper("values", ["latitude", "longitude"], 1024) +} + + +options = { + "axis_config": [ + {"axis_name": "step", "transformations": [{"name": "type_change", "type": "int"}]}, + {"axis_name": "number", "transformations": [{"name": "type_change", "type": "int"}]}, + # { + # "axis_name": "date", + # "transformations": [{"name": "merge", "other_axis": "time", "linkers": ["T", "00"]}], + # }, + # {"axis_name": "date", "transformations": [{"name": "type_change", "type": "date"}]}, + # {"axis_name": "time", "transformations": [{"name": "type_change", "type": "time"}]}, + { + "axis_name": "values", + "transformations": [ + {"name": "mapper", "type": "healpix_nested", "resolution": 1024, "axes": ["latitude", "longitude"]} + ], + }, + {"axis_name": "latitude", "transformations": [{"name": "reverse", "is_reverse": True}]}, + {"axis_name": "longitude", "transformations": [{"name": "cyclic", "range": [0, 360]}]}, + ], + "compressed_axes_config": [ + "longitude", + "latitude", + "levtype", + "step", + "date", + "domain", + "expver", + "param", + "class", + "stream", + "type", + ], + "pre_path": {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper"}, +} + +# request = Request( +# Select("step", [0]), +# Select("levtype", ["sfc"]), +# Select("date", [pd.Timestamp("20230625T120000")]), +# Select("domain", ["g"]), +# Select("expver", ["0001"]), +# Select("param", ["167"]), +# Select("class", ["od"]), +# Select("stream", ["oper"]), +# Select("type", ["an"]), +# Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]), +# ) + +request = Request( + # ConvexPolytope(["param"], [["164"]]), + Select("param", ["164"]), + ConvexPolytope(["time"], [[pd.Timedelta(hours=0, minutes=0)], [pd.Timedelta(hours=3, minutes=0)]]), + ConvexPolytope(["resolution"], [["high"]]), + ConvexPolytope(["type"], [["fc"]]), + # ConvexPolytope(["model"], [['ifs-nemo']]), + Select("model", ["ifs-nemo"]), + Select("stream", ["clte"]), + # ConvexPolytope(["stream"], [["clte"]]), + ConvexPolytope(["realization"], ["1"]), + ConvexPolytope(["expver"], [['0001']]), + ConvexPolytope(["experiment"], [['ssp3-7.0']]), + ConvexPolytope(["generation"], [["1"]]), + ConvexPolytope(["levtype"], [["sfc"]]), + # ConvexPolytope(["activity"], [["scenariomip"]]), + Select("activity", ["scenariomip"]), + ConvexPolytope(["dataset"], [["climate-dt"]]), + ConvexPolytope(["class"], [["d1"]]), + ConvexPolytope(["date"], [[pd.Timestamp("20220811")]]), + ConvexPolytope(["latitude", "longitude"], [[0, 0], [5, 5], [0, 5]])) + +fdb_tree = get_fdb_tree(request) + +print("HERE WE HAVE THE FDB TREE") +print(fdb_tree) + +qubeddatacube = QubedDatacube(fdb_tree, datacube_axes, datacube_transformations) +slicer = QubedSlicer() +self_API = Polytope( + datacube=qubeddatacube, + engine=slicer, + options=options, +) +time1 = time.time() +result = self_API.retrieve(request) +time2 = time.time() + +# print(result) + +print("TIME EXTRACTING USING QUBED") +print(time2 - time1) + +# # USING NORMAL GJ + + +# options = { +# "axis_config": [ +# {"axis_name": "step", "transformations": [{"name": "type_change", "type": "int"}]}, +# {"axis_name": "number", "transformations": [{"name": "type_change", "type": "int"}]}, +# # { +# # "axis_name": "date", +# # "transformations": [{"name": "merge", "other_axis": "time", "linkers": ["T", "00"]}], +# # }, +# {"axis_name": "date", "transformations": [{"name": "type_change", "type": "date"}]}, +# {"axis_name": "time", "transformations": [{"name": "type_change", "type": "time"}]}, +# { +# "axis_name": "values", +# "transformations": [ +# {"name": "mapper", "type": "healpix_nested", "resolution": 1024, "axes": ["latitude", "longitude"]} +# ], +# }, +# {"axis_name": "latitude", "transformations": [{"name": "reverse", "is_reverse": True}]}, +# {"axis_name": "longitude", "transformations": [{"name": "cyclic", "range": [0, 360]}]}, +# ], +# "compressed_axes_config": [ +# "longitude", +# # "latitude", +# # "levtype", +# # "step", +# # "date", +# # "domain", +# # "expver", +# # "param", +# # "class", +# # "stream", +# # "type", +# ], +# "pre_path": {"class": "d1", "model": "ifs-nemo", "resolution": "high"}, +# } + +# fdbdatacube = gj.GribJump() +# slicer = HullSlicer() +# self_API = Polytope( +# datacube=fdbdatacube, +# engine=slicer, +# options=options, +# ) + + +# request = Request(ConvexPolytope(["param"], [["164"]]), +# ConvexPolytope(["time"], [[pd.Timedelta(hours=1, minutes=0)], [pd.Timedelta(hours=3, minutes=0)]]), +# ConvexPolytope(["resolution"], [["high"]]), +# ConvexPolytope(["type"], [["fc"]]), +# ConvexPolytope(["model"], [['ifs-nemo']]), +# ConvexPolytope(["stream"], [["clte"]]), +# ConvexPolytope(["realization"], ["1"]), +# ConvexPolytope(["expver"], [['0001']]), +# ConvexPolytope(["experiment"], [['ssp3-7.0']]), +# ConvexPolytope(["generation"], [["1"]]), +# ConvexPolytope(["levtype"], [["sfc"]]), +# ConvexPolytope(["activity"], [["scenariomip"]]), +# ConvexPolytope(["dataset"], [["climate-dt"]]), +# ConvexPolytope(["class"], [["d1"]]), +# ConvexPolytope(["date"], [[pd.Timestamp("20220811")]]), +# ConvexPolytope(["latitude", "longitude"], [[0, 0], [5, 5], [0, 5]])) + +# time3 = time.time() +# result = self_API.retrieve(request) +# time4 = time.time() + +# print("TIME EXTRACTING USING GJ NORMAL") +# print(time4 - time3) + + +# # print(result) + +# # print(result.leaves) + +# # sliced_tree = actual_slice(fdb_tree, combi_polytopes, datacube_axes, datacube_transformations)