|
| 1 | +from collections.abc import Callable |
| 2 | +import numpy as np |
| 3 | +from keras.saving import ( |
| 4 | + deserialize_keras_object as deserialize, |
| 5 | + register_keras_serializable as serializable, |
| 6 | + serialize_keras_object as serialize, |
| 7 | + get_registered_name, |
| 8 | + get_registered_object, |
| 9 | +) |
| 10 | +from .elementwise_transform import ElementwiseTransform |
| 11 | +from ...utils import filter_kwargs |
| 12 | +import inspect |
| 13 | + |
| 14 | + |
| 15 | +@serializable(package="bayesflow.adapters") |
| 16 | +class SerializableCustomTransform(ElementwiseTransform): |
| 17 | + """ |
| 18 | + Transforms a parameter using a pair of registered serializable forward and inverse functions. |
| 19 | +
|
| 20 | + Parameters |
| 21 | + ---------- |
| 22 | + forward : function, no lambda |
| 23 | + Registered serializable function to transform the data in the forward pass. |
| 24 | + For the adapter to be serializable, this function has to be serializable |
| 25 | + as well (see Notes). Therefore, only proper functions and no lambda |
| 26 | + functions can be used here. |
| 27 | + inverse : function, no lambda |
| 28 | + Function to transform the data in the inverse pass. |
| 29 | + For the adapter to be serializable, this function has to be serializable |
| 30 | + as well (see Notes). Therefore, only proper functions and no lambda |
| 31 | + functions can be used here. |
| 32 | +
|
| 33 | + Raises |
| 34 | + ------ |
| 35 | + ValueError |
| 36 | + When the provided functions are not registered serializable functions. |
| 37 | +
|
| 38 | + Notes |
| 39 | + ----- |
| 40 | + Important: The forward and inverse functions have to be registered with Keras. |
| 41 | + To do so, use the `@keras.saving.register_keras_serializable` decorator. |
| 42 | + They must also be registered (and identical) when loading the adapter |
| 43 | + at a later point in time. |
| 44 | +
|
| 45 | + """ |
| 46 | + |
| 47 | + def __init__( |
| 48 | + self, |
| 49 | + *, |
| 50 | + forward: Callable[[np.ndarray, ...], np.ndarray], |
| 51 | + inverse: Callable[[np.ndarray, ...], np.ndarray], |
| 52 | + ): |
| 53 | + super().__init__() |
| 54 | + |
| 55 | + self._check_serializable(forward, label="forward") |
| 56 | + self._check_serializable(inverse, label="inverse") |
| 57 | + self._forward = forward |
| 58 | + self._inverse = inverse |
| 59 | + |
| 60 | + @classmethod |
| 61 | + def _check_serializable(cls, function, label=""): |
| 62 | + GENERAL_EXAMPLE_CODE = ( |
| 63 | + "The example code below shows the structure of a correctly decorated function:\n\n" |
| 64 | + "```\n" |
| 65 | + "import keras\n\n" |
| 66 | + "@keras.saving.register_keras_serializable('custom')\n" |
| 67 | + f"def my_{label}(...):\n" |
| 68 | + " [your code goes here...]\n" |
| 69 | + "```\n" |
| 70 | + ) |
| 71 | + if function is None: |
| 72 | + raise TypeError( |
| 73 | + f"'{label}' must be a registered serializable function, was 'NoneType'.\n{GENERAL_EXAMPLE_CODE}" |
| 74 | + ) |
| 75 | + registered_name = get_registered_name(function) |
| 76 | + # check if function is a lambda function |
| 77 | + if registered_name == "<lambda>": |
| 78 | + raise ValueError( |
| 79 | + f"The provided function for '{label}' is a lambda function, " |
| 80 | + "which cannot be serialized. " |
| 81 | + "Please provide a registered serializable function by using the " |
| 82 | + "@keras.saving.register_keras_serializable decorator." |
| 83 | + f"\n{GENERAL_EXAMPLE_CODE}" |
| 84 | + ) |
| 85 | + if inspect.ismethod(function): |
| 86 | + raise ValueError( |
| 87 | + f"The provided value for '{label}' is a method, not a function. " |
| 88 | + "Methods cannot be serialized separately from their classes. " |
| 89 | + "Please provide a registered serializable function instead by " |
| 90 | + "moving the functionality to a function (i.e., outside of the class) and " |
| 91 | + "using the @keras.saving.register_keras_serializable decorator." |
| 92 | + f"\n{GENERAL_EXAMPLE_CODE}" |
| 93 | + ) |
| 94 | + registered_object_for_name = get_registered_object(registered_name) |
| 95 | + if registered_object_for_name is None: |
| 96 | + try: |
| 97 | + source_max_lines = 5 |
| 98 | + function_source_code = inspect.getsource(function).split("\n") |
| 99 | + if len(function_source_code) > source_max_lines: |
| 100 | + function_source_code = function_source_code[:source_max_lines] + [" [...]"] |
| 101 | + |
| 102 | + example_code = "For your provided function, this would look like this:\n\n" |
| 103 | + example_code += "\n".join( |
| 104 | + ["```", "import keras\n", "@keras.saving.register_keras_serializable('custom')"] |
| 105 | + + function_source_code |
| 106 | + + ["```"] |
| 107 | + ) |
| 108 | + except OSError: |
| 109 | + example_code = GENERAL_EXAMPLE_CODE |
| 110 | + raise ValueError( |
| 111 | + f"The provided function for '{label}' is not registered with Keras.\n" |
| 112 | + "Please register the function using the " |
| 113 | + "@keras.saving.register_keras_serializable decorator.\n" |
| 114 | + f"{example_code}" |
| 115 | + ) |
| 116 | + if registered_object_for_name is not function: |
| 117 | + raise ValueError( |
| 118 | + f"The provided function for '{label}' does not match the function " |
| 119 | + f"registered under its name '{registered_name}'. " |
| 120 | + f"(registered function: {registered_object_for_name}, provided function: {function}). " |
| 121 | + ) |
| 122 | + |
| 123 | + @classmethod |
| 124 | + def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTransform": |
| 125 | + if get_registered_object(config["forward"]["config"], custom_objects) is None: |
| 126 | + provided_function_msg = "" |
| 127 | + if config["_forward_source_code"]: |
| 128 | + provided_function_msg = ( |
| 129 | + f"\nThe originally provided function was:\n\n```\n{config['_forward_source_code']}\n```" |
| 130 | + ) |
| 131 | + raise TypeError( |
| 132 | + "\n\nPLEASE READ HERE:\n" |
| 133 | + "-----------------\n" |
| 134 | + "The forward function that was provided as `forward` " |
| 135 | + "is not registered with Keras, making deserialization impossible. " |
| 136 | + f"Please ensure that it is registered as '{config['forward']['config']}' and identical to the original " |
| 137 | + "function before loading your model." |
| 138 | + f"{provided_function_msg}" |
| 139 | + ) |
| 140 | + if get_registered_object(config["inverse"]["config"], custom_objects) is None: |
| 141 | + provided_function_msg = "" |
| 142 | + if config["_inverse_source_code"]: |
| 143 | + provided_function_msg = ( |
| 144 | + f"\nThe originally provided function was:\n\n```\n{config['_inverse_source_code']}\n```" |
| 145 | + ) |
| 146 | + raise TypeError( |
| 147 | + "\n\nPLEASE READ HERE:\n" |
| 148 | + "-----------------\n" |
| 149 | + "The inverse function that was provided as `inverse` " |
| 150 | + "is not registered with Keras, making deserialization impossible. " |
| 151 | + f"Please ensure that it is registered as '{config['inverse']['config']}' and identical to the original " |
| 152 | + "function before loading your model." |
| 153 | + f"{provided_function_msg}" |
| 154 | + ) |
| 155 | + forward = deserialize(config["forward"], custom_objects) |
| 156 | + inverse = deserialize(config["inverse"], custom_objects) |
| 157 | + return cls( |
| 158 | + forward=forward, |
| 159 | + inverse=inverse, |
| 160 | + ) |
| 161 | + |
| 162 | + def get_config(self) -> dict: |
| 163 | + forward_source_code = inverse_source_code = None |
| 164 | + try: |
| 165 | + forward_source_code = inspect.getsource(self._forward) |
| 166 | + inverse_source_code = inspect.getsource(self._inverse) |
| 167 | + except OSError: |
| 168 | + pass |
| 169 | + return { |
| 170 | + "forward": serialize(self._forward), |
| 171 | + "inverse": serialize(self._inverse), |
| 172 | + "_forward_source_code": forward_source_code, |
| 173 | + "_inverse_source_code": inverse_source_code, |
| 174 | + } |
| 175 | + |
| 176 | + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: |
| 177 | + # filter kwargs so that other transform args like batch_size, strict, ... are not passed through |
| 178 | + kwargs = filter_kwargs(kwargs, self._forward) |
| 179 | + return self._forward(data, **kwargs) |
| 180 | + |
| 181 | + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: |
| 182 | + kwargs = filter_kwargs(kwargs, self._inverse) |
| 183 | + return self._inverse(data, **kwargs) |
0 commit comments