[PR]: Improving regrid2 performance #533
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅
Additional details and impacted files

@@ Coverage Diff @@
##              main     #533   +/-  ##
=========================================
  Coverage   100.00%   100.00%
=========================================
  Files           15        15
  Lines         1602      1588    -14
=========================================
- Hits          1602      1588    -14
Hey @jasonb5, just checking in to see your estimated timeline for when this will be ready for review and merge. I'm shooting to have xCDAT v0.6.0 released in the next week or so.
Notes from 9/13/23 meeting:
10/11/23 Meeting Notes: Next steps:
Any status updates here?
@tomvothecoder @lee1043 @chengzhuzhang @pochedls
I placed the notebook under … If everything looks alright, let's merge and I'll get out the next few fixes and continue working on performance.
Thank you for the update. Sure, I can test the branch.
tomvothecoder left a comment:
Hi @jasonb5, here's my initial code review with questions and minor suggestions.
    mapping = [
        np.where(
            np.logical_and(
-               shifted_src_west < dst_east[i], shifted_src_east > dst_west[i]
+               shifted_src_west < dst_east[x], shifted_src_east > dst_west[x]
            )
        )[0]
        for x in range(dst_length)
    ]
-   weight = np.minimum(dst_east[i], shifted_src_east[contrib]) - np.maximum(
-       dst_west[i], shifted_src_west[contrib]
-   )
-
-   weights.append(weight.values.reshape(1, contrib.shape[0]))
-
-   contrib += shift
+   weights = [
+       (
+           np.minimum(dst_east[x], shifted_src_east[y])
+           - np.maximum(dst_west[x], shifted_src_west[y])
+       ).reshape((1, -1))
+       for x, y in enumerate(mapping)
+   ]
Same comment as above about adding a comment to explain the logic and purpose.
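For context, here is a small self-contained sketch of what these two comprehensions compute, using made-up bound arrays (the real values come from the shifted source bounds and the destination grid bounds): for each destination cell, mapping holds the indices of the source cells whose bounds overlap it, and weights holds the width of each overlap.

    import numpy as np

    # Made-up 1-D cell bounds: four source cells and two destination cells.
    shifted_src_west = np.array([0.0, 10.0, 20.0, 30.0])
    shifted_src_east = np.array([10.0, 20.0, 30.0, 40.0])
    dst_west = np.array([0.0, 25.0])
    dst_east = np.array([25.0, 40.0])
    dst_length = dst_west.shape[0]

    # Indices of the source cells that overlap each destination cell.
    mapping = [
        np.where(
            np.logical_and(
                shifted_src_west < dst_east[x], shifted_src_east > dst_west[x]
            )
        )[0]
        for x in range(dst_length)
    ]
    # -> [array([0, 1, 2]), array([2, 3])]

    # Width of each overlap, one row vector per destination cell.
    weights = [
        (
            np.minimum(dst_east[x], shifted_src_east[y])
            - np.maximum(dst_west[x], shifted_src_west[y])
        ).reshape((1, -1))
        for x, y in enumerate(mapping)
    ]
    # -> [array([[10., 10., 5.]]), array([[5., 10.]])]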
xcdat/regridder/regrid2.py (outdated)
    name = input_data_var.cf.axes[cf_axis_name]

    if isinstance(name, list):
        name = name[0]
Instead of using cf_xarray directly, I think you can use xc.get_dim_keys(), which can also interpret the standard_name attribute or fall back to xCDAT's table of generally accepted axis names. Unless the intent here is to only interpret the axis CF attribute?
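For illustration, a minimal sketch of the suggested call (the dataset, file name, and variable name here are hypothetical):

    import xcdat as xc

    ds = xc.open_dataset("input.nc")  # hypothetical input file

    # get_dim_keys() interprets the CF "axis" attribute, can fall back to
    # "standard_name", and finally to xCDAT's table of accepted axis names.
    y_key = xc.get_dim_keys(ds["tas"], axis="Y")
    x_key = xc.get_dim_keys(ds["tas"], axis="X")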
I could use get_dim_keys, but we would be trading performance for robustness. If we accept this, I'm fine making the change.
I see. Is the performance hit significant using get_dim_keys()? If so, I think it is fine to only interpret the axis attribute for performance.
@lee1043 any thoughts?
Since this part of the code is at the back-end level and less likely to be accessed by users, I would prefer prioritizing performance, unless the robustness trade-off is too significant. How much of a performance change would this make?
Try passing the data variable directly to get_dim_keys().
Looks like using get_dim_keys works just fine with no decrease in performance; there's actually a small increase.
    try:
        name = ds.cf.bounds[axis][0]
    except (KeyError, IndexError):
        raise RuntimeError(f"Could not determine {axis!r} bounds")
Suggested change:

-   try:
-       name = ds.cf.bounds[axis][0]
-   except (KeyError, IndexError):
-       raise RuntimeError(f"Could not determine {axis!r} bounds")
+   try:
+       name = ds.bounds.get_bounds(axis)
+   except (ValueError, KeyError):
+       raise RuntimeError(f"Could not determine {axis!r} bounds")
I think you can use xCDAT's ds.bounds.get_bounds().
As noted above, I could use get_bounds, but we're again trading performance for robustness.
Gotcha. If the current implementation is faster and using .get_bounds() isn't necessary, we can keep your current implementation.
After some review, I'm going to leave the ds.cf.bounds usage. In this case it's fine to use because it's being called on the input/output grids, both of which have been generated/validated by previous code. We can guarantee that both grid objects only have lat/lon coordinates/bounds with the correct metadata, so we do not need the robustness of ds.bounds.get_bounds.
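A rough sketch of the two lookups being compared, on a grid produced by xCDAT's own grid utilities (the grid resolution below is arbitrary):

    import xcdat as xc

    # Grid generated by xCDAT, so the lat/lon coordinates and bounds should
    # already carry the expected CF metadata.
    grid = xc.create_uniform_grid(-90, 90, 4.0, -180, 180, 5.0)

    # cf_xarray lookup kept in regrid2: returns the bounds variable *name*
    # for the axis, which is then pulled out of the grid object.
    lat_bnds = grid[grid.cf.bounds["Y"][0]]

    # More robust xCDAT accessor: returns the bounds DataArray directly.
    lat_bnds_alt = grid.bounds.get_bounds("Y")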
+   for y in range(y_length):
+       y_seg = np.take(input_data, lat_mapping[y], axis=y_index)

-   for lon_index, lon_map in enumerate(self._lon_mapping):
-       lon_weight = self._lon_weights[lon_index]
+       for x in range(x_length):
+           x_seg = np.take(y_seg, lon_mapping[x], axis=x_index, mode="wrap")

-           dot_weight = np.dot(lat_weight, lon_weight)
+           cell_weight = np.dot(lat_weights[y], lon_weights[x])

-           cell_weight = np.sum(dot_weight)
+           output_seg_index = y * x_length + x

-           input_lon_segment = np.take(
-               input_lat_segment, lon_map, axis=input_lon_index
-           )
+           if is_2d:
+               output_data[output_seg_index] = np.divide(
+                   np.sum(
+                       np.multiply(x_seg, cell_weight),
+                       axis=(y_index, x_index),
+                   ),
+                   np.sum(cell_weight),
+               )

-           data = (
-               np.nansum(
-                   np.multiply(input_lon_segment, dot_weight),
-                   axis=(input_lat_index, input_lon_index),
-               )
-               / cell_weight
-           )
+           else:
+               output_seg = output_data[output_seg_index]

+               np.divide(
+                   np.sum(
+                       np.multiply(x_seg, cell_weight),
+                       axis=(y_index, x_index),
+                   ),
+                   np.sum(cell_weight),
+                   out=output_seg,
+               )
Comments explaining the logic here would be good. Maybe in the docstring.
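For a future docstring, something along these lines might capture the per-cell operation (the array shapes and values below are purely illustrative, not taken from the code):

    import numpy as np

    # Two source rows and two source columns contribute to one output cell.
    # The outer product of the latitude and longitude overlap weights gives a
    # per-source-cell weight; the output value is the weighted mean of the
    # contributing source values.
    lat_weight = np.array([[0.25], [0.75]])       # (2, 1) latitude overlaps
    lon_weight = np.array([[0.5, 0.5]])           # (1, 2) longitude overlaps
    cell_weight = np.dot(lat_weight, lon_weight)  # (2, 2) combined weights

    src_values = np.array([[1.0, 2.0], [3.0, 4.0]])  # contributing source cells

    output_value = np.sum(src_values * cell_weight) / np.sum(cell_weight)
    # -> 3.0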
lee1043 left a comment:
Sorry for the delayed approval. I thought I had marked it approved, but I missed it.
Description
After some analysis, I determined we were losing some performance moving back and forth between xarray and numpy.
The first fix ensures we're doing all the heavy computation in numpy. This has reduced the example time from ~4.8583 to ~1.6833.
There's still some more room for improvement.
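As a rough illustration of the pattern described above (not the actual regrid2 code): hand the data off to numpy once, keep the heavy arithmetic there, and only wrap the result back into xarray at the end.

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.random.rand(180, 360), dims=("lat", "lon"))

    data = da.values                  # single handoff from xarray to numpy

    result = data.mean(axis=1)        # heavy computation stays in plain numpy

    output = xr.DataArray(result, dims=("lat",))  # wrap up once at the end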
Checklist
If applicable: