spotpython-0.16.9 released

Changes in spotpython-0.16.9:

  • xai.py: add new function viz_net to visualize the network architecture (linear nets).
  • dimensions.py: add new function extract_linear_dims that extracts the input and output dimensions of the Linear layers in a PyTorch model.
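To illustrate what extracting the dimensions of the Linear layers means, here is a minimal sketch in plain PyTorch. It is not spotpython's implementation: the helper name linear_dims_sketch and its return format (a list of (in_features, out_features) tuples) are assumptions made for this illustration, and extract_linear_dims itself may return the dimensions in a different form.

```python
import torch.nn as nn


def linear_dims_sketch(model: nn.Module) -> list:
    """Hypothetical helper: collect (in_features, out_features) for every
    nn.Linear layer in the model, in registration order."""
    return [
        (m.in_features, m.out_features)
        for m in model.modules()
        if isinstance(m, nn.Linear)
    ]


# A small linear net with three Linear layers: 10 -> 8 -> 4 -> 1
net = nn.Sequential(
    nn.Linear(10, 8),
    nn.ReLU(),
    nn.Linear(8, 4),
    nn.ReLU(),
    nn.Linear(4, 1),
)

dims = linear_dims_sketch(net)
print(dims)  # [(10, 8), (8, 4), (4, 1)]
```

Walking model.modules() also reaches Linear layers nested inside sub-modules, so the sketch works for models that are not flat nn.Sequential containers.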

The following example demonstrates how to use the viz_net() function:

from spotpython.plot.xai import viz_net
from spotpython.utils.init import fun_control_init
from spotpython.data.diabetes import Diabetes
from spotpython.light.regression.nn_linear_regressor import NNLinearRegressor
from spotpython.hyperdict.light_hyper_dict import LightHyperDict
from spotpython.hyperparameters.values import (
    get_default_hyperparameters_as_array,
    get_one_config_from_X,
)
# Input/output dimensions of the network and the metric
_L_in = 10
_L_out = 1
_torchmetric = "mean_squared_error"
fun_control = fun_control_init(
    _L_in=_L_in,
    _L_out=_L_out,
    _torchmetric=_torchmetric,
    data_set=Diabetes(),
    core_model=NNLinearRegressor,
    hyperdict=LightHyperDict,
)
# Build one configuration from the default hyperparameters
X = get_default_hyperparameters_as_array(fun_control)
config = get_one_config_from_X(X, fun_control)
# Instantiate the model and visualize its architecture
model = fun_control["core_model"](
    **config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric
)
viz_net(
    net=model, device="cpu", show_attrs=True, show_saved=True,
    filename="model_architecture", format="png",
)