Axon.Losses (Axon v0.7.0)

Loss functions.

Loss functions evaluate predictions with respect to true data, often to measure the divergence between a model's representation of the data-generating distribution and the true representation of the data-generating distribution.

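Each loss function measures the loss with respect to the input target y_true and input prediction y_pred and, by default, returns unreduced losses (typically one value per example) rather than a single scalar. In the formulas below, $\hat{y}$ denotes the target (y_true) and $y$ the prediction (y_pred). As an example, the mean_squared_error/2 loss function produces a tensor whose values are the mean squared error between targets and predictions, computed along the last axis: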

iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 0.0]], type: {:f, 32})
iex> Axon.Losses.mean_squared_error(y_true, y_pred)
#Nx.Tensor<
  f32[2]
  [0.5, 0.5]
>
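To make the arithmetic behind the example above explicit, here is a minimal sketch in plain Elixir (lists instead of tensors). The module name MseSketch is hypothetical and not part of Axon; it only assumes that the loss is averaged over the last axis, which is what the documented output suggests.

```elixir
defmodule MseSketch do
  # Mean squared error for one example: the average of the squared
  # differences over the last axis (here, a plain list of floats).
  def per_example(y_true, y_pred) do
    squared =
      Enum.zip(y_true, y_pred)
      |> Enum.map(fn {t, p} -> (t - p) * (t - p) end)

    Enum.sum(squared) / length(squared)
  end

  # Applies the per-example loss to every row of a batch.
  def batch(y_true, y_pred) do
    Enum.zip(y_true, y_pred)
    |> Enum.map(fn {t, p} -> per_example(t, p) end)
  end
end

losses = MseSketch.batch([[0.0, 1.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 0.0]])
IO.inspect(losses)
```

Each row contributes one squared error of 1.0 and one of 0.0, so both per-example losses are 0.5.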

It's common to compute the loss across an entire minibatch. You can do so by specifying a :reduction mode, or by composing a loss function with an Nx reduction function such as Nx.mean/1 or Nx.sum/1:

iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 0.0]], type: {:f, 32})
iex> Axon.Losses.mean_squared_error(y_true, y_pred, reduction: :mean)
#Nx.Tensor<
  f32
  0.5
>
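The three reduction modes can be pictured as a final step applied to the unreduced losses. The sketch below uses plain lists and a hypothetical ReductionSketch module; it is an illustration of the idea, not Axon's implementation.

```elixir
defmodule ReductionSketch do
  # Collapses a list of per-example losses according to a reduction mode.
  def reduce(losses, :none), do: losses
  def reduce(losses, :sum), do: Enum.sum(losses)
  def reduce(losses, :mean), do: Enum.sum(losses) / length(losses)
end

per_example = [0.5, 0.5]
IO.inspect(ReductionSketch.reduce(per_example, :none))
IO.inspect(ReductionSketch.reduce(per_example, :mean))
IO.inspect(ReductionSketch.reduce(per_example, :sum))
```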

You can even compose loss functions:

defn my_strange_loss(y_true, y_pred) do
  y_true
  |> Axon.Losses.mean_squared_error(y_pred)
  |> Axon.Losses.binary_cross_entropy(y_pred)
  |> Nx.sum()
end

Or, more commonly, you can combine loss functions with penalties for regularization:

defn regularized_loss(params, y_true, y_pred) do
  loss = Axon.Losses.mean_squared_error(y_true, y_pred)
  # l2_penalty/1 is assumed to be defined elsewhere; it is not part of Axon.Losses
  penalty = l2_penalty(params)
  Nx.sum(loss) + penalty
end
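The penalty function in the example above is left undefined. As a rough illustration of what such a helper computes, the sketch below implements an L2 penalty over a flat list of weights in plain Elixir. The module name, the flat-list representation, and the lambda default of 0.01 are all assumptions for illustration; a real helper would operate on the model's parameter tensors with Nx.

```elixir
defmodule L2Sketch do
  # Hypothetical helper: sums the squares of every weight and scales
  # the result by lambda.
  def l2_penalty(weights, lambda \\ 0.01) do
    lambda * Enum.sum(Enum.map(weights, fn w -> w * w end))
  end

  # Total objective: summed per-example losses plus the penalty.
  def regularized(per_example_losses, weights, lambda \\ 0.01) do
    Enum.sum(per_example_losses) + l2_penalty(weights, lambda)
  end
end

total = L2Sketch.regularized([0.5, 0.5], [1.0, -2.0, 0.5])
IO.inspect(total)
```

Here the squared weights sum to 5.25, so the penalty is about 0.0525 on top of a summed loss of 1.0.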

All of the functions in this module are implemented as numerical functions and can be JIT or AOT compiled with any supported Nx compiler.

Summary

Functions

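apply_label_smoothing(y_true, y_pred, opts \\ [])

Applies label smoothing to the given labels.

binary_cross_entropy(y_true, y_pred, opts \\ [])

Binary cross-entropy loss function.

categorical_cross_entropy(y_true, y_pred, opts \\ [])

Categorical cross-entropy loss function.

categorical_hinge(y_true, y_pred, opts \\ [])

Categorical hinge loss function.

connectionist_temporal_classification(arg1, y_pred, opts \\ [])

Connectionist Temporal Classification loss.

cosine_similarity(y_true, y_pred, opts \\ [])

Cosine Similarity error loss function.

hinge(y_true, y_pred, opts \\ [])

Hinge loss function.

huber(y_true, y_pred, opts \\ [])

Huber loss.

kl_divergence(y_true, y_pred, opts \\ [])

Kullback-Leibler divergence loss function.

label_smoothing(loss_fun, opts \\ [])

Modifies the given loss function to smooth labels prior to calculating loss.

log_cosh(y_true, y_pred, opts \\ [])

Logarithmic-Hyperbolic Cosine loss function.

margin_ranking(y_true, arg2, opts \\ [])

Margin ranking loss function.

mean_absolute_error(y_true, y_pred, opts \\ [])

Mean-absolute error loss function.

mean_squared_error(y_true, y_pred, opts \\ [])

Mean-squared error loss function.

poisson(y_true, y_pred, opts \\ [])

Poisson loss function.

soft_margin(y_true, y_pred, opts \\ [])

Soft margin loss function.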

Functions

apply_label_smoothing(y_true, y_pred, opts \\ [])

Applies label smoothing to the given labels.

Label smoothing is a regularization technique that shrinks targets towards a uniform distribution. It can improve model generalization.

Options

  • :smoothing - smoothing factor. Defaults to 0.1
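As a worked illustration of shrinking targets towards a uniform distribution, the sketch below applies the common formulation y * (1 - smoothing) + smoothing / num_classes to one row of labels. This section does not state the exact formula Axon uses, so treat the formula and the LabelSmoothingSketch module as assumptions.

```elixir
defmodule LabelSmoothingSketch do
  # Moves each target a fraction `smoothing` of the way towards the
  # uniform distribution over the classes in the row.
  def smooth(y_true_row, smoothing \\ 0.1) do
    num_classes = length(y_true_row)
    Enum.map(y_true_row, fn y -> y * (1 - smoothing) + smoothing / num_classes end)
  end
end

smoothed = LabelSmoothingSketch.smooth([0.0, 1.0, 0.0, 0.0], 0.1)
IO.inspect(smoothed)
```

With four classes and a smoothing factor of 0.1, the one-hot target 1.0 becomes roughly 0.925 and each 0.0 becomes roughly 0.025, so the row still sums to 1.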

binary_cross_entropy(y_true, y_pred, opts \\ [])

Binary cross-entropy loss function.

$$ l_i = -\frac{1}{C} \sum_i^C (\hat{y_i} \cdot \log(y_i) + (1 - \hat{y_i}) \cdot \log(1 - y_i)) $$

Binary cross-entropy loss is most often used in binary classification problems. By default, it expects y_pred to encode probabilities in the range [0.0, 1.0], typically as the output of the sigmoid function or another function that squeezes values between 0 and 1. You may optionally set from_logits: true to specify that y_pred contains unnormalized values (logits, with possibly unbounded range). In this case, the values are converted to probabilities by applying the logistic sigmoid function before the loss is computed.
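The following plain-Elixir sketch works through the computation for a single example, including the sigmoid step used when predictions are logits. BceSketch is a hypothetical module, and the sketch omits the numerical clipping a production implementation would need when a probability is exactly 0 or 1.

```elixir
defmodule BceSketch do
  # Logistic sigmoid, used to turn logits into probabilities.
  def sigmoid(x), do: 1.0 / (1.0 + :math.exp(-x))

  # Binary cross-entropy for one example, averaged over its elements.
  # When from_logits is true, predictions go through sigmoid first.
  def per_example(y_true, y_pred, from_logits \\ false) do
    probs = if from_logits, do: Enum.map(y_pred, &sigmoid/1), else: y_pred

    terms =
      Enum.zip(y_true, probs)
      |> Enum.map(fn {t, p} -> -(t * :math.log(p) + (1 - t) * :math.log(1 - p)) end)

    Enum.sum(terms) / length(terms)
  end
end

probs = [0.6811, 0.5565]
loss = BceSketch.per_example([0, 1], probs)

# The same predictions expressed as logits give the same loss.
logits = Enum.map(probs, fn p -> :math.log(p / (1 - p)) end)
loss_from_logits = BceSketch.per_example([0, 1], logits, true)

IO.inspect({loss, loss_from_logits})
```

The first value should agree with the first entry of the documented example output (about 0.8645).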

Argument Shapes

  • y_true - $(d_0, d_1, ..., d_n)$
  • y_pred - $(d_0, d_1, ..., d_n)$

Options

  • :reduction - reduction mode. One of :mean, :sum, or :none. Defaults to :none.

  • :negative_weight - weight for the 0 class, useful for scaling loss by class importance. Defaults to 1.0.

  • :positive_weight - weight for the 1 class, useful for scaling loss by class importance. Defaults to 1.0.

  • :from_logits - whether y_pred is a logits tensor. Defaults to false.

Examples

iex> y_true = Nx.tensor([[0, 1], [1, 0], [1, 0]])
iex> y_pred = Nx.tensor([[0.6811, 0.5565], [0.6551, 0.4551], [0.5422, 0.2648]])
iex> Axon.Losses.binary_cross_entropy(y_true, y_pred)
#Nx.Tensor<
  f32[3]
  [0.8644826412200928, 0.5150600075721741, 0.45986634492874146]
>

iex> y_true = Nx.tensor([[0, 1], [1, 0], [1, 0]])
iex> y_pred = Nx.tensor([[0.6811, 0.5565], [0.6551, 0.4551], [0.5422, 0.2648]])
iex> Axon.Losses.binary_cross_entropy(y_true, y_pred, reduction: :mean)
#Nx.Tensor<
  f32
  0.613136351108551
>

iex> y_true = Nx.tensor([[0, 1], [1, 0], [1, 0]])
iex> y_pred = Nx.tensor([[0.6811, 0.5565], [0.6551, 0.4551], [0.5422, 0.2648]])
iex> Axon.Losses.binary_cross_entropy(y_true, y_pred, reduction: :sum)
#Nx.Tensor<
  f32
  1.8394089937210083
>

categorical_cross_entropy(y_true, y_pred, opts \\ [])

Categorical cross-entropy loss function.

$$ l_i = -\sum_i^C \hat{y_i} \cdot \log(y_i) $$

Categorical cross-entropy is typically used for multi-class classification problems. By default, it expects y_pred to encode a probability distribution along the last axis. You can specify from_logits: true to indicate y_pred is a logits tensor.

# Batch size of 3 with 3 target classes (integer class targets, as used with sparse: true)
y_true = Nx.tensor([0, 2, 1])
y_pred = Nx.tensor([[0.2, 0.8, 0.0], [0.1, 0.2, 0.7], [0.1, 0.2, 0.7]])
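The difference between one-hot and integer ("sparse") targets can be sketched in plain Elixir as follows. CceSketch is a hypothetical module; it skips zero targets rather than clipping probabilities, which is a simplification.

```elixir
defmodule CceSketch do
  # Dense targets: y_true is a one-hot row. Terms with a zero target are
  # skipped, following the convention 0 * log(0) = 0.
  def dense(y_true, y_pred) do
    Enum.zip(y_true, y_pred)
    |> Enum.filter(fn {t, _p} -> t != 0 end)
    |> Enum.map(fn {t, p} -> -t * :math.log(p) end)
    |> Enum.sum()
  end

  # Sparse targets: y_true is the integer index of the correct class.
  def sparse(class_index, y_pred) do
    -1.0 * :math.log(Enum.at(y_pred, class_index))
  end
end

dense_loss = CceSketch.dense([0, 1, 0], [0.05, 0.95, 0.0])
sparse_loss = CceSketch.sparse(1, [0.05, 0.95, 0.0])
IO.inspect({dense_loss, sparse_loss})
```

Both calls reduce to the negative log of the probability assigned to the correct class, so they return the same value.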

Argument Shapes

  • y_true - $(d_0, d_1, ..., d_n)$
  • y_pred - $(d_0, d_1, ..., d_n)$

Options

  • :reduction - reduction mode. One of :mean, :sum, or :none. Defaults to :none.

  • :class_weights - 1-D list of per-class weights, useful for scaling loss according to class importance. Its length must match the number of classes in the dataset. Defaults to 1.0 for all classes.

  • :from_logits - whether y_pred is a logits tensor. Defaults to false.

  • :sparse - whether y_true encodes a "sparse" tensor. In this case the inputs are integer values corresponding to the target class. Defaults to false.

Examples

iex> y_true = Nx.tensor([[0, 1, 0], [0, 0, 1]], type: {:s, 8})
iex> y_pred = Nx.tensor([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
iex> Axon.Losses.categorical_cross_entropy(y_true, y_pred)
#Nx.Tensor<
  f32[2]
  [0.051293306052684784, 2.3025851249694824]
>

iex> y_true = Nx.tensor([[0, 1, 0], [0, 0, 1]], type: {:s, 8})
iex> y_pred = Nx.tensor([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
iex> Axon.Losses.categorical_cross_entropy(y_true, y_pred, reduction: :mean)
#Nx.Tensor<
  f32
  1.1769392490386963
>

iex> y_true = Nx.tensor([[0, 1, 0], [0, 0, 1]], type: {:s, 8})
iex> y_pred = Nx.tensor([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
iex> Axon.Losses.categorical_cross_entropy(y_true, y_pred, reduction: :sum)
#Nx.Tensor<
  f32
  2.3538784980773926
>

iex> y_true = Nx.tensor([1, 2], type: {:s, 8})
iex> y_pred = Nx.tensor([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
iex> Axon.Losses.categorical_cross_entropy(y_true, y_pred, reduction: :sum, sparse: true)
#Nx.Tensor<
  f32
  2.3538784980773926
>

categorical_hinge(y_true, y_pred, opts \\ [])

Categorical hinge loss function.
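No formula is given for this loss here. The sketch below assumes the standard definition, max(0, neg - pos + 1), where pos is the score of the true class and neg is the highest score among the other classes. The CategoricalHingeSketch module is hypothetical.

```elixir
defmodule CategoricalHingeSketch do
  # pos: score assigned to the true class.
  # neg: highest score among the remaining classes.
  def per_example(y_true, y_pred) do
    pairs = Enum.zip(y_true, y_pred)
    pos = pairs |> Enum.map(fn {t, p} -> t * p end) |> Enum.sum()
    neg = pairs |> Enum.map(fn {t, p} -> (1 - t) * p end) |> Enum.max()
    max(0.0, neg - pos + 1.0)
  end
end

first = CategoricalHingeSketch.per_example([1, 0, 0], [0.05300799, 0.21617081, 0.68642382])
second = CategoricalHingeSketch.per_example([0, 0, 1], [0.3754382, 0.08494169, 0.13442067])
IO.inspect({first, second})
```

Under this assumed definition, the two values agree with the unreduced output documented for this function to within floating-point precision.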

Argument Shapes

  • y_true - $(d_0, d_1, ..., d_n)$
  • y_pred - $(d_0, d_1, ..., d_n)$

Options

  • :reduction - reduction mode. One of :mean, :sum, or :none. Defaults to :none.

Examples

iex> y_true = Nx.tensor([[1, 0, 0], [0, 0, 1]], type: {:s, 8})
iex> y_pred = Nx.tensor([[0.05300799, 0.21617081, 0.68642382], [0.3754382 , 0.08494169, 0.13442067]])
iex> Axon.Losses.categorical_hinge(y_true, y_pred)
#Nx.Tensor<
  f32[2]
  [1.6334158182144165, 1.2410175800323486]
>

iex> y_true = Nx.tensor([[1, 0, 0], [0, 0, 1]], type: {:s, 8})
iex> y_pred = Nx.tensor([[0.05300799, 0.21617081, 0.68642382], [0.3754382 , 0.08494169, 0.13442067]])
iex> Axon.Losses.categorical_hinge(y_true, y_pred, reduction: :mean)
#Nx.Tensor<
  f32
  1.4372167587280273
>

iex> y_true = Nx.tensor([[1, 0, 0], [0, 0, 1]], type: {:s, 8})
iex> y_pred = Nx.tensor([[0.05300799, 0.21617081, 0.68642382], [0.3754382 , 0.08494169, 0.13442067]])
iex> Axon.Losses.categorical_hinge(y_true, y_pred, reduction: :sum)
#Nx.Tensor<
  f32
  2.8744335174560547
>

connectionist_temporal_classification(arg1, y_pred, opts \\ [])

Connectionist Temporal Classification loss.

Argument Shapes

  • l_true - $(B)$
  • y_true - $(B, S)$
  • y_pred - $(B, T, D)$

Options

  • :reduction - reduction mode. One of :sum or :none. Defaults to :none.

Description

l_true and y_true are passed together as the first argument (arg1 in the signature above).

  • l_true - lengths of the target sequences. Values must be positive (nonzero).
  • y_true - target sequences. Each value is a class index in the range 0 <= y < D. The blank class is included in this range but must not appear among the y_true values. The maximum target sequence length must be less than or equal to the y_pred sequence length: S <= T.
  • y_pred - log probabilities of the D classes along the prediction sequence T.

cosine_similarity(y_true, y_pred, opts \\ [])

Cosine Similarity error loss function.

$$ l_i = \frac{\sum_i \hat{y_i} \cdot y_i}{\lVert \hat{y} \rVert_2 \cdot \lVert y \rVert_2} $$

Argument Shapes

  • y_true - $(d_0, d_1, ..., d_n)$
  • y_pred - $(d_0, d_1, ..., d_n)$

Options

  • :reduction - reduction mode. One of :mean, :sum, or :none. Defaults to :none.
  • :axes - axes along which the similarity is computed. Defaults to [1].
  • :eps - small constant used for numerical stability. Defaults to 1.0e-6.
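Cosine similarity is the dot product of two vectors divided by the product of their Euclidean norms. The sketch below computes it for plain lists; keeping the denominator at or above eps is an assumed use of the :eps option, and CosineSketch is a hypothetical module.

```elixir
defmodule CosineSketch do
  # Euclidean (L2) norm of a list of floats.
  def norm(v), do: :math.sqrt(Enum.sum(Enum.map(v, fn x -> x * x end)))

  # Dot product divided by the product of norms. The denominator is kept
  # at or above eps to avoid division by zero.
  def similarity(a, b, eps \\ 1.0e-6) do
    dot = Enum.zip(a, b) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()
    dot / max(norm(a) * norm(b), eps)
  end
end

orthogonal = CosineSketch.similarity([0.0, 1.0], [1.0, 0.0])
parallel = CosineSketch.similarity([1.0, 1.0], [1.0, 1.0])
IO.inspect({orthogonal, parallel})
```

Orthogonal vectors give 0.0 and parallel vectors give a value very close to 1.0.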

Examples

iex> y_pred = Nx.tensor([[1.0, 0.0], [1.0, 1.0]])
iex> y_true = Nx.tensor([[0.0, 1.0], [1.0, 1.0]])
iex> Axon.Losses.cosine_similarity(y_true, y_pred)
#Nx.Tensor<
  f32[2]
  [0.0, 1.0000001192092896]
>

hinge(y_true, y_pred, opts \\ [])

Hinge loss function.

$$ l_i = \frac{1}{C} \sum_i^C \max(1 - \hat{y_i} \cdot y_i, 0) $$
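As a worked check, the sketch below computes the hinge loss for one example with plain lists, taking the mean of max(0, 1 - y_true * y_pred) over the elements. Targets are assumed to be -1 or 1, and HingeSketch is a hypothetical module.

```elixir
defmodule HingeSketch do
  # Each element contributes max(0, 1 - y_true * y_pred); the example
  # loss is the mean of those contributions.
  def per_example(y_true, y_pred) do
    terms =
      Enum.zip(y_true, y_pred)
      |> Enum.map(fn {t, p} -> max(0.0, 1.0 - t * p) end)

    Enum.sum(terms) / length(terms)
  end
end

loss = HingeSketch.per_example([1, 1, -1], [0.45440044, 0.31470688, 0.67920924])
IO.inspect(loss)
```

The result should agree with the first value of the unreduced output documented for this function (about 0.9700).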

Options

  • :reduction - reduction mode. One of :mean, :sum, or :none. Defaults to :none.

Argument Shapes

  • y_true - $(d_0, d_1, ..., d_n)$
  • y_pred - $(d_0, d_1, ..., d_n)$

Examples

iex> y_true = Nx.tensor([[ 1,  1, -1], [ 1,  1, -1]], type: {:s, 8})
iex> y_pred = Nx.tensor([[0.45440044, 0.31470688, 0.67920924], [0.24311459, 0.93466766, 0.10914676]])
iex> Axon.Losses.hinge(y_true, y_pred)
#Nx.Tensor<
  f32[2]
  [0.9700339436531067, 0.6437881588935852]
>

iex> y_true = Nx.tensor([[ 1,  1, -1], [ 1,  1, -1]], type: {:s, 8})
iex> y_pred = Nx.tensor([[0.45440044, 0.31470688, 0.67920924], [0.24311459, 0.93466766, 0.10914676]])
iex> Axon.Losses.hinge(y_true, y_pred, reduction: :mean)
#Nx.Tensor<
  f32
  0.806911051273346
>

iex> y_true = Nx.tensor([[ 1,  1, -1], [ 1,  1, -1]], type: {:s, 8})
iex> y_pred = Nx.tensor([[0.45440044, 0.31470688, 0.67920924], [0.24311459, 0.93466766, 0.10914676]])
iex> Axon.Losses.hinge(y_true, y_pred, reduction: :sum)
#Nx.Tensor<
  f32
  1.613822102546692
>

huber(y_true, y_pred, opts \\ [])

Huber loss.

Argument Shapes

  • y_true - $(d_0, d_1, ..., d_n)$
  • y_pred - $(d_0, d_1, ..., d_n)$

Options

  • :reduction - reduction mode. One of :mean, :sum, or :none. Defaults to :none.

  • :delta - the point where the Huber loss function changes from a quadratic to linear. Defaults to 1.0.
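The switch from quadratic to linear at :delta can be sketched element-wise in plain Elixir. This assumes the standard Huber definition: 0.5 * e^2 when |e| <= delta, and delta * (|e| - 0.5 * delta) otherwise. HuberSketch is a hypothetical module.

```elixir
defmodule HuberSketch do
  # Quadratic for small errors, linear once |error| exceeds delta.
  def elementwise(y_true, y_pred, delta \\ 1.0) do
    abs_err = abs(y_true - y_pred)

    if abs_err <= delta do
      0.5 * abs_err * abs_err
    else
      delta * (abs_err - 0.5 * delta)
    end
  end
end

small = HuberSketch.elementwise(1.0, 0.8)
large = HuberSketch.elementwise(0.0, 3.0)
IO.inspect({small, large})
```

An error of 0.2 falls in the quadratic region (about 0.02), while an error of 3.0 with delta 1.0 falls in the linear region (2.5).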

Examples

iex> y_true = Nx.tensor([[1], [1.5], [2.0]])
iex> y_pred = Nx.tensor([[0.8], [1.8], [2.1]])
iex> Axon.Losses.huber(y_true, y_pred)
#Nx.Tensor<
  f32[3][1]
  [
    [0.019999997690320015],
    [0.04499998688697815],
    [0.004999990575015545]
  ]
>

iex> y_true = Nx.tensor([[1], [1.5], [2.0]])
iex> y_pred = Nx.tensor([[0.8], [1.8], [2.1]])
iex> Axon.Losses.huber(y_true, y_pred, reduction: :mean)
#Nx.Tensor<
  f32
  0.02333332598209381
>

kl_divergence(y_true, y_pred, opts \\ [])

Kullback-Leibler divergence loss function.

$$ l_i = \sum_i^C \hat{y_i} \cdot \log(\frac{\hat{y_i}}{y_i}) $$
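The formula can be evaluated directly for one example, as in the plain-Elixir sketch below. KlSketch is a hypothetical module, and it treats terms with a zero target as contributing 0 rather than clipping inputs, so results may differ slightly from a tensor implementation that clips for numerical stability.

```elixir
defmodule KlSketch do
  # Sum of t * log(t / p) over classes; terms with t == 0 contribute 0.
  def per_example(y_true, y_pred) do
    Enum.zip(y_true, y_pred)
    |> Enum.filter(fn {t, _p} -> t > 0 end)
    |> Enum.map(fn {t, p} -> t * :math.log(t / p) end)
    |> Enum.sum()
  end
end

first = KlSketch.per_example([0, 1], [0.6, 0.4])
second = KlSketch.per_example([0, 0], [0.4, 0.6])
IO.inspect({first, second})
```

The first example reduces to log(1 / 0.4), about 0.9163, and the second has no nonzero targets, so it evaluates to 0.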

Argument Shapes

  • y_true - $(d_0, d_1, ..., d_n)$
  • y_pred - $(d_0, d_1, ..., d_n)$

Options

  • :reduction - reduction mode. One of :mean, :sum, or :none. Defaults to :none.

Examples

iex> y_true = Nx.tensor([[0, 1], [0, 0]], type: {:u, 8})
iex> y_pred = Nx.tensor([[0.6, 0.4], [0.4, 0.6]])
iex> Axon.Losses.kl_divergence(y_true, y_pred)
#Nx.Tensor<
  f32[2]
  [0.916289210319519, -3.080907390540233e-6]
>

iex> y_true = Nx.tensor([[0, 1], [0, 0]], type: {:u, 8})
iex> y_pred = Nx.tensor([[0.6, 0.4], [0.4, 0.6]])
iex> Axon.Losses.kl_divergence(y_true, y_pred, reduction: :mean)
#Nx.Tensor<
  f32
  0.45814305543899536
>

iex> y_true = Nx.tensor([[0, 1], [0, 0]], type: {:u, 8})
iex> y_pred = Nx.tensor([[0.6, 0.4], [0.4, 0.6]])
iex> Axon.Losses.kl_divergence(y_true, y_pred, reduction: :sum)
#Nx.Tensor<
  f32
  0.9162861108779907
>

label_smoothing(loss_fun, opts \\ [])

Modifies the given loss function to smooth labels prior to calculating loss.

See apply_label_smoothing/2 for details.

Options

  • :smoothing - smoothing factor. Defaults to 0.1
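The idea of modifying a loss function so that labels are smoothed first can be sketched as a higher-order function. Everything below is hypothetical and uses plain lists: SmoothedLossSketch.wrap/2, the stand-in l1 loss, and the smoothing formula y * (1 - smoothing) + smoothing / num_classes are assumptions for illustration.

```elixir
defmodule SmoothedLossSketch do
  # Returns a new 2-arity loss function that smooths the targets and
  # then delegates to the wrapped loss function.
  def wrap(loss_fun, smoothing \\ 0.1) do
    fn y_true, y_pred ->
      n = length(y_true)
      smoothed = Enum.map(y_true, fn y -> y * (1 - smoothing) + smoothing / n end)
      loss_fun.(smoothed, y_pred)
    end
  end
end

# A stand-in loss: the sum of absolute differences.
l1 = fn y_true, y_pred ->
  Enum.zip(y_true, y_pred) |> Enum.map(fn {t, p} -> abs(t - p) end) |> Enum.sum()
end

smoothed_l1 = SmoothedLossSketch.wrap(l1, 0.1)
plain = l1.([0.0, 1.0], [0.0, 1.0])
smoothed = smoothed_l1.([0.0, 1.0], [0.0, 1.0])
IO.inspect({plain, smoothed})
```

A perfect prediction has zero loss against the raw labels but a small nonzero loss against the smoothed labels, which is the regularizing effect.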

log_cosh(y_true, y_pred, opts \\ [])

Logarithmic-Hyperbolic Cosine loss function.

$$ l_i = \frac{1}{C} \sum_i^C \left[ (\hat{y_i} - y_i) + \log(1 + e^{-2(\hat{y_i} - y_i)}) - \log(2) \right] $$
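The summand in the formula is an algebraically equivalent way of writing log(cosh(x)) for the error x. The plain-Elixir sketch below checks that identity and computes the loss for one example; LogCoshSketch is a hypothetical module.

```elixir
defmodule LogCoshSketch do
  # log(cosh(x)) written as x + log(1 + e^(-2x)) - log(2).
  def log_cosh(x), do: x + :math.log(1.0 + :math.exp(-2.0 * x)) - :math.log(2.0)

  # Mean of log-cosh of the element-wise errors for one example.
  def per_example(y_true, y_pred) do
    terms =
      Enum.zip(y_true, y_pred)
      |> Enum.map(fn {t, p} -> log_cosh(p - t) end)

    Enum.sum(terms) / length(terms)
  end
end

direct = :math.log(:math.cosh(1.0))
rewritten = LogCoshSketch.log_cosh(1.0)
loss = LogCoshSketch.per_example([0.0, 1.0], [1.0, 1.0])
IO.inspect({direct, rewritten, loss})
```

Because log(cosh(x)) is an even function, the sign convention of the error does not affect the result.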

Argument Shapes

  • y_true - $(d_0, d_1, ..., d_n)$
  • y_pred - $(d_0, d_1, ..., d_n)$

Options

  • :reduction - reduction mode. One of :mean, :sum, or :none. Defaults to :none.

Examples

iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]])
iex> y_pred = Nx.tensor([[1.0, 1.0], [0.0, 0.0]])
iex> Axon.Losses.log_cosh(y_true, y_pred)
#Nx.Tensor<
  f32[2]
  [0.2168903946876526, 0.0]
>

iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]])
iex> y_pred = Nx.tensor([[1.0, 1.0], [0.0, 0.0]])
iex> Axon.Losses.log_cosh(y_true, y_pred, reduction: :mean)
#Nx.Tensor<
  f32
  0.1084451973438263
>

iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]])
iex> y_pred = Nx.tensor([[1.0, 1.0], [0.0, 0.0]])
iex> Axon.Losses.log_cosh(y_true, y_pred, reduction: :sum)
#Nx.Tensor<
  f32
  0.2168903946876526
>

margin_ranking(y_true, arg2, opts \\ [])

Margin ranking loss function.

$$ l_i = \max(0, -\hat{y_i} \cdot (y^{(1)}_i - y^{(2)}_i) + \alpha) $$
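The formula can be evaluated element by element, as in the sketch below. The margin parameter stands in for $\alpha$ and is assumed to default to 0.0, which is consistent with the documented example output for this function. MarginRankingSketch is a hypothetical module.

```elixir
defmodule MarginRankingSketch do
  # y_true is 1.0 when the first input should rank higher and -1.0 when
  # the second should. The loss is zero when the ranking is respected by
  # at least `margin`.
  def elementwise(y_true, y_pred1, y_pred2, margin \\ 0.0) do
    max(0.0, -y_true * (y_pred1 - y_pred2) + margin)
  end
end

respected = MarginRankingSketch.elementwise(1.0, 0.6934, -0.4691)
violated = MarginRankingSketch.elementwise(1.0, -0.7239, 0.2670)
IO.inspect({respected, violated})
```

The first pair is ranked correctly and incurs no loss; the second is ranked the wrong way round and is penalized by the size of the gap.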

Options

  • :reduction - reduction mode. One of :mean, :sum, or :none. Defaults to :none.

Examples

iex> y_true = Nx.tensor([1.0, 1.0, 1.0], type: {:f, 32})
iex> y_pred1 = Nx.tensor([0.6934, -0.7239,  1.1954], type: {:f, 32})
iex> y_pred2 = Nx.tensor([-0.4691, 0.2670, -1.7452], type: {:f, 32})
iex> Axon.Losses.margin_ranking(y_true, {y_pred1, y_pred2})
#Nx.Tensor<
  f32[3]
  [0.0, 0.9909000396728516, 0.0]
>

iex> y_true = Nx.tensor([1.0, 1.0, 1.0], type: {:f, 32})
iex> y_pred1 = Nx.tensor([0.6934, -0.7239,  1.1954], type: {:f, 32})
iex> y_pred2 = Nx.tensor([-0.4691, 0.2670, -1.7452], type: {:f, 32})
iex> Axon.Losses.margin_ranking(y_true, {y_pred1, y_pred2}, reduction: :mean)
#Nx.Tensor<
  f32
  0.3303000032901764
>

iex> y_true = Nx.tensor([1.0, 1.0, 1.0], type: {:f, 32})
iex> y_pred1 = Nx.tensor([0.6934, -0.7239,  1.1954], type: {:f, 32})
iex> y_pred2 = Nx.tensor([-0.4691, 0.2670, -1.7452], type: {:f, 32})
iex> Axon.Losses.margin_ranking(y_true, {y_pred1, y_pred2}, reduction: :sum)
#Nx.Tensor<
  f32
  0.9909000396728516
>

mean_absolute_error(y_true, y_pred, opts \\ [])

Mean-absolute error loss function.

$$ l_i = \frac{1}{C} \sum_i^C |\hat{y_i} - y_i| $$

Argument Shapes

  • y_true - $(d_0, d_1, ..., d_n)$
  • y_pred - $(d_0, d_1, ..., d_n)$

Options

  • :reduction - reduction mode. One of :mean, :sum, or :none. Defaults to :none.

Examples

iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 0.0]], type: {:f, 32})
iex> Axon.Losses.mean_absolute_error(y_true, y_pred)
#Nx.Tensor<
  f32[2]
  [0.5, 0.5]
>

iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 0.0]], type: {:f, 32})
iex> Axon.Losses.mean_absolute_error(y_true, y_pred, reduction: :mean)
#Nx.Tensor<
  f32
  0.5
>

iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 0.0]], type: {:f, 32})
iex> Axon.Losses.mean_absolute_error(y_true, y_pred, reduction: :sum)
#Nx.Tensor<
  f32
  1.0
>

mean_squared_error(y_true, y_pred, opts \\ [])

Mean-squared error loss function.

$$ l_i = \frac{1}{C} \sum_i^C (\hat{y_i} - y_i)^2 $$

Argument Shapes

  • y_true - $(d_0, d_1, ..., d_n)$
  • y_pred - $(d_0, d_1, ..., d_n)$

Options

  • :reduction - reduction mode. One of :mean, :sum, or :none. Defaults to :none.

Examples

iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 0.0]], type: {:f, 32})
iex> Axon.Losses.mean_squared_error(y_true, y_pred)
#Nx.Tensor<
  f32[2]
  [0.5, 0.5]
>

iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 0.0]], type: {:f, 32})
iex> Axon.Losses.mean_squared_error(y_true, y_pred, reduction: :mean)
#Nx.Tensor<
  f32
  0.5
>

iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 0.0]], type: {:f, 32})
iex> Axon.Losses.mean_squared_error(y_true, y_pred, reduction: :sum)
#Nx.Tensor<
  f32
  1.0
>

poisson(y_true, y_pred, opts \\ [])

Poisson loss function.

$$ l_i = \frac{1}{C} \sum_i^C \left( y_i - \hat{y_i} \cdot \log(y_i) \right) $$
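A plain-Elixir sketch of the formula for one example follows. The eps term added inside the logarithm is an assumption that guards against log(0) when a prediction is exactly zero; PoissonSketch is a hypothetical module.

```elixir
defmodule PoissonSketch do
  # Mean over classes of y_pred - y_true * log(y_pred + eps).
  def per_example(y_true, y_pred, eps \\ 1.0e-7) do
    terms =
      Enum.zip(y_true, y_pred)
      |> Enum.map(fn {t, p} -> p - t * :math.log(p + eps) end)

    Enum.sum(terms) / length(terms)
  end
end

first = PoissonSketch.per_example([0.0, 1.0], [1.0, 1.0])
second = PoissonSketch.per_example([0.0, 0.0], [0.0, 0.0])
IO.inspect({first, second})
```

The first example evaluates to just under 1.0 because of the eps term, and the second evaluates to zero.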

Argument Shapes

  • y_true - $(d_0, d_1, ..., d_n)$
  • y_pred - $(d_0, d_1, ..., d_n)$

Options

  • :reduction - reduction mode. One of :mean, :sum, or :none. Defaults to :none.

Examples

iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[1.0, 1.0], [0.0, 0.0]], type: {:f, 32})
iex> Axon.Losses.poisson(y_true, y_pred)
#Nx.Tensor<
  f32[2]
  [0.9999999403953552, 0.0]
>

iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[1.0, 1.0], [0.0, 0.0]], type: {:f, 32})
iex> Axon.Losses.poisson(y_true, y_pred, reduction: :mean)
#Nx.Tensor<
  f32
  0.4999999701976776
>

iex> y_true = Nx.tensor([[0.0, 1.0], [0.0, 0.0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[1.0, 1.0], [0.0, 0.0]], type: {:f, 32})
iex> Axon.Losses.poisson(y_true, y_pred, reduction: :sum)
#Nx.Tensor<
  f32
  0.9999999403953552
>

soft_margin(y_true, y_pred, opts \\ [])

Soft margin loss function.

$$ l_i = \sum_i \frac{\log(1 + e^{-\hat{y_i} * y_i})}{N} $$
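The element-wise term inside the formula, log(1 + e^(-y_true * y_pred)), can be checked with a short plain-Elixir sketch. Targets are assumed to be -1 or 1, and SoftMarginSketch is a hypothetical module.

```elixir
defmodule SoftMarginSketch do
  # Element-wise soft margin term for one target/prediction pair.
  def elementwise(y_true, y_pred), do: :math.log(1.0 + :math.exp(-y_true * y_pred))
end

terms =
  Enum.zip([-1.0, 1.0, 1.0], [0.2953, -0.1709, 0.9486])
  |> Enum.map(fn {t, p} -> SoftMarginSketch.elementwise(t, p) end)

IO.inspect(terms)
```

The three terms should agree with the unreduced output documented for this function, and their mean with the :mean reduction.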

Options

  • :reduction - reduction mode. One of :mean, :sum, or :none. Defaults to :none.

Examples

iex> y_true = Nx.tensor([[-1.0, 1.0,  1.0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[0.2953, -0.1709, 0.9486]], type: {:f, 32})
iex> Axon.Losses.soft_margin(y_true, y_pred)
#Nx.Tensor<
  f32[3]
  [0.851658046245575, 0.7822436094284058, 0.3273470401763916]
>

iex> y_true = Nx.tensor([[-1.0, 1.0,  1.0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[0.2953, -0.1709, 0.9486]], type: {:f, 32})
iex> Axon.Losses.soft_margin(y_true, y_pred, reduction: :mean)
#Nx.Tensor<
  f32
  0.6537495255470276
>

iex> y_true = Nx.tensor([[-1.0, 1.0,  1.0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[0.2953, -0.1709, 0.9486]], type: {:f, 32})
iex> Axon.Losses.soft_margin(y_true, y_pred, reduction: :sum)
#Nx.Tensor<
  f32
  1.9612486362457275
>