r"""
.. _apps-training-tutorial:

Training variational GBS distributions
======================================

*Technical details are available in the API documentation:*
:doc:`code/api/strawberryfields.apps.train`

Many quantum algorithms rely on the ability to train the parameters of quantum circuits, a
strategy inspired by the success of neural networks in machine learning. Training is often
performed by evaluating gradients of a cost function with respect to circuit parameters,
then employing gradient-based optimization methods. In this demonstration, we outline the
theoretical principles for training Gaussian Boson Sampling (GBS) circuits, first
introduced in Ref. [[#banchi2020training]_]. We then explain how to employ Strawberry Fields
to perform the training by looking at basic examples in stochastic optimization and unsupervised
learning. Let's go! 🚀

Theory
------

As explained in more detail in :doc:`/concepts/gbs`, for a GBS device, the probability
:math:`\Pr(S)` of observing an output :math:`S=(s_1, s_2, \ldots, s_m)`, where :math:`s_i` is the
number of photons detected in the :math:`i`-th mode, is given by

.. math::

    \Pr(S) = \frac{1}{\mathcal{N}} \frac{|\text{Haf}(A_S)|^2}{
    s_1!\ldots s_m!},

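where :math:`\mathcal{N}` is a normalization constant, :math:`A` is an arbitrary symmetric
matrix with eigenvalues bounded between :math:`-1` and :math:`1`, and :math:`A_S` is the matrix
obtained by repeating the :math:`i`-th row and column of :math:`A` a total of :math:`s_i` times
(removing them when :math:`s_i = 0`). The matrix :math:`A` can also be rescaled by a constant
factor, which is equivalent to fixing a total mean photon number in the distribution.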
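
To make the rescaling concrete, here is a minimal sketch that picks the constant :math:`c` so that
:math:`cA` has a prescribed total mean photon number. It assumes a pure-state encoding in which each
eigenvalue :math:`\lambda_j` of :math:`cA` sets a single-mode squeezer with
:math:`\tanh r_j = |\lambda_j|`, so that the total mean photon number is
:math:`\bar{n} = \sum_j \lambda_j^2/(1-\lambda_j^2)`. Because :math:`\bar{n}` grows monotonically
with :math:`c`, a bisection search suffices. The helper names ``mean_photons`` and
``scale_for_mean_photons`` are hypothetical and not part of the Strawberry Fields API.

```python
import numpy as np


def mean_photons(A: np.ndarray) -> float:
    # Total mean photon number of a pure GBS state encoding the real symmetric A,
    # assuming each eigenvalue lam sets a squeezer with tanh(r) = |lam|, so that
    # the mode contributes sinh(r)**2 = lam**2 / (1 - lam**2) photons on average.
    lam = np.linalg.eigvalsh(A)
    return float(np.sum(lam**2 / (1 - lam**2)))


def scale_for_mean_photons(A: np.ndarray, n_mean: float, tol: float = 1e-12) -> float:
    # Hypothetical helper: bisection for c > 0 such that mean_photons(c * A) == n_mean.
    lam_max = np.max(np.abs(np.linalg.eigvalsh(A)))
    lo, hi = 0.0, 1.0 / lam_max  # c * lam_max must stay strictly below 1
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if mean_photons(mid * A) < n_mean:
            lo = mid
        else:
            hi = mid
    return lo


# Adjacency matrix of a triangle graph, used here only as an example
A = np.array([[0.0, 1.0, 1.0],
              [1.0, 0.0, 1.0],
              [1.0, 1.0, 0.0]])
c = scale_for_mean_photons(A, n_mean=2.0)
print("c =", c, "mean photons =", mean_photons(c * A))
```

The upper end of the bisection interval keeps every eigenvalue of :math:`cA` strictly inside
:math:`(-1, 1)`, matching the boundedness condition above.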

We want to *train* this distribution to perform a specific task. For example, we may wish to
reproduce the statistical properties of a given dataset, or to optimize the circuit to sample
specific patterns with high probability. The strategy is to identify a parametrization of
the distribution and compute gradients that allow us to optimize the parameters using
gradient-based techniques. We refer to these circuits as variational GBS circuits, or VGBS for short.

Gradient formulas can generally be challenging to calculate, but a particular strategy known as
the WAW parametrization (wow! what a great name) leads to gradients that are simpler to compute.
It involves transforming the symmetric matrix :math:`A` as

.. math::

    A \rightarrow A_W = W A W,

where :math:`W = \text{diag}(\sqrt{w_1}, \sqrt{w_2}, \ldots, \sqrt{w_m})` is a diagonal weight matrix and :math:`m`
is the number of modes. This parametrization is useful because the hafnian of :math:`A_W`
factorizes into two separate components

.. math::

    \text{Haf}(A_W) = \text{Haf}(A)\text{det}(W),

a property that can be cleverly exploited to compute gradients more efficiently. More broadly,
it is convenient to embed trainable parameters :math:`\theta = (\theta_1, \ldots, \theta_d)` into
the weights :math:`w_k`. Several choices are possible, but here we focus on an exponential embedding

.. math::

    w_k = \exp(-\theta^T f^{(k)}),

where each :math:`f^{(k)}` is a :math:`d`-dimensional vector. The simplest case is obtained by
setting :math:`d=m` and taking :math:`f^{(k)}` to be the :math:`k`-th standard basis vector, so that
:math:`\theta^T f^{(k)} = \theta_k` and :math:`w_k = \exp(-\theta_k)`.
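
The factorization can be checked numerically on a small example. The sketch below uses a
brute-force hafnian that sums over all perfect matchings (exponential time, so only for tiny
matrices), the simple embedding :math:`w_k = \exp(-\theta_k)`, and a hypothetical ``submatrix``
helper that builds :math:`A_S` by repeating the :math:`k`-th row and column :math:`s_k` times. It
also checks the submatrix version of the identity,
:math:`\text{Haf}[(A_W)_S] = \text{Haf}(A_S)\prod_k w_k^{s_k/2}`, which is why :math:`\Pr(S)`
picks up a factor :math:`\prod_k w_k^{s_k}` in the WAW parametrization.

```python
import numpy as np


def hafnian(M: np.ndarray) -> float:
    # Brute-force hafnian: sum over all perfect matchings of the row/column indices.
    # Exponential time, so only suitable for tiny matrices.
    n = M.shape[0]
    if n == 0:
        return 1.0
    if n % 2 == 1:
        return 0.0
    total = 0.0
    rest = list(range(1, n))
    for j in rest:
        # Pair index 0 with j, then recurse on the remaining indices
        others = [k for k in rest if k != j]
        total += M[0, j] * hafnian(M[np.ix_(others, others)])
    return total


def submatrix(M: np.ndarray, s) -> np.ndarray:
    # Hypothetical helper: build M_S by repeating row/column k a total of s[k] times
    idx = np.repeat(np.arange(len(s)), s)
    return M[np.ix_(idx, idx)]


rng = np.random.default_rng(seed=7)
m = 4
X = rng.normal(size=(m, m))
A = (X + X.T) / 2                 # an arbitrary real symmetric matrix

theta = rng.normal(size=m)        # trainable parameters (d = m)
w = np.exp(-theta)                # exponential embedding w_k = exp(-theta_k)
W = np.diag(np.sqrt(w))           # W = diag(sqrt(w_1), ..., sqrt(w_m))
A_W = W @ A @ W

# Full-matrix identity: Haf(A_W) = Haf(A) det(W)
print(hafnian(A_W), hafnian(A) * np.linalg.det(W))

# Submatrix identity for one output pattern S
s = np.array([1, 2, 0, 1])
print(hafnian(submatrix(A_W, s)), hafnian(submatrix(A, s)) * np.prod(w ** (s / 2)))
```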

Training tasks
^^^^^^^^^^^^^^

In stochastic optimization, we are given a function :math:`h(S)` and the goal is to optimize the
parameters to sample from a distribution :math:`P_{\theta}(S)` that minimizes the
expectation value

.. math::

    C (\theta) = \sum_{S} h(S) P_{\theta}(S).
        
As shown in [[#banchi2020training]_], the gradient of the cost function :math:`C (\theta)` is given
by

.. math::

    \partial_{\theta} C (\theta) = \sum_{S} h(S) P_{\theta}(S)
    \sum_{k=1}^{m}  (s_k - \langle s_{k} \rangle) \partial_{\theta} \log w_{k},

where :math:`\langle s_k\rangle` denotes the average photon number in mode :math:`k`. This gradient
is an expectation value with respect to the GBS distribution, so it can be estimated by generating
samples from the device. Convenient!
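
A quick sanity check of this formula uses a toy distribution with the same structure. In the WAW
parametrization with :math:`w_k = \exp(-\theta_k)`, the probability is proportional to
:math:`g(S)\prod_k w_k^{s_k}`, where :math:`g(S) = |\text{Haf}(A_S)|^2/(s_1!\ldots s_m!)` does not
depend on :math:`\theta`. The sketch below assumes, purely for illustration, random positive values
of :math:`g(S)` on a truncated set of patterns with :math:`s_k \in \{0, 1, 2\}`, so that everything
can be enumerated exactly. It compares the gradient formula with finite differences of
:math:`C(\theta)` and with a sample-based estimate; all function names are hypothetical.

```python
import itertools

import numpy as np

rng = np.random.default_rng(seed=0)
m = 3
subset = [0, 1]  # target modes for the toy cost function h

# Truncated pattern space with s_k in {0, 1, 2}; an assumption that makes exact
# enumeration possible (a real GBS distribution has unbounded photon numbers)
patterns = np.array(list(itertools.product(range(3), repeat=m)), dtype=float)
# Toy stand-in for g(S) = |Haf(A_S)|^2 / (s_1! ... s_m!): random positive numbers
g = rng.uniform(0.5, 1.5, size=len(patterns))


def h(s):
    not_subset = [k for k in range(len(s)) if k not in subset]
    return s[not_subset].sum() - s[subset].sum()


h_vals = np.array([h(s) for s in patterns])


def probs(theta):
    # P_theta(S) proportional to g(S) * prod_k w_k**s_k with w_k = exp(-theta_k)
    logp = np.log(g) - patterns @ theta
    p = np.exp(logp - logp.max())
    return p / p.sum()


def cost(theta):
    # C(theta) = sum_S h(S) P_theta(S), computed exactly by enumeration
    return float(h_vals @ probs(theta))


def grad_formula(theta):
    # sum_S h(S) P(S) sum_k (s_k - <s_k>) d log w_k / d theta_j, where the
    # derivative equals -1 if j == k and 0 otherwise
    p = probs(theta)
    mean_s = p @ patterns
    return -(p * h_vals) @ (patterns - mean_s)


def grad_sampled(theta, n_samples, rng):
    # Same expectation value, estimated from samples as one would with a device
    p = probs(theta)
    idx = rng.choice(len(patterns), size=n_samples, p=p)
    s = patterns[idx]
    return -np.mean(h_vals[idx][:, None] * (s - p @ patterns), axis=0)


theta = rng.normal(size=m)
eps = 1e-6
finite_diff = np.array([(cost(theta + eps * e) - cost(theta - eps * e)) / (2 * eps)
                        for e in np.eye(m)])
print(grad_formula(theta))
print(finite_diff)
print(grad_sampled(theta, 20000, rng))
```

The three printed vectors should agree, with the sampled estimate carrying statistical noise that
shrinks as the number of samples grows.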

In a standard unsupervised learning scenario, data are assumed to be sampled from an unknown
distribution and a common goal is to learn that distribution. More precisely, the goal is
to use the data to train a model that can sample from a similar distribution, thus being able to
generate new data. Training can be performed by minimizing the Kullback-Leibler (KL) divergence,
which up to additive constants can be written as:

.. math::

    KL(\theta) = -\frac{1}{T}\sum_S \log[P_{\theta}(S)].

In this case, the sum runs over the elements :math:`S` of the data, :math:`P_{\theta}(S)` is the
probability of observing that element when sampling from the GBS distribution, and :math:`T` is the
total number of elements in the data. For the GBS distribution in the WAW parametrization, the
gradient of the KL divergence can be written as

.. math::

    \partial_\theta KL(\theta) = - \sum_{k=1}^m\frac{1}{w_k}(\langle s_k\rangle_{\text{data}}-
    \langle s_k\rangle_{\text{GBS}})\partial_\theta w_k.

Remarkably, this gradient can be evaluated without a quantum computer because the
mean photon numbers :math:`\langle s_k\rangle_{\text{GBS}}` over the GBS distribution can be
efficiently computed classically [[#banchi2020training]_]. As we'll show later, this leads to
very fast training. This is true even if sampling the distribution remains classically
intractable! 🤯
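
To see why no quantum device is needed here, consider a minimal sketch under two assumptions:
photon-number-resolving detectors, and a pure GBS state that encodes a real symmetric matrix
:math:`B` with eigenvalues in :math:`(-1, 1)`. Under these assumptions the mean photon numbers are
the diagonal entries of :math:`B^2(I - B^2)^{-1}`, which costs a single matrix inversion. With the
embedding :math:`w_k = \exp(-\theta_k)` we have :math:`\partial_{\theta_j} w_k = -w_k\delta_{jk}`,
so the gradient above reduces to
:math:`\langle s_j\rangle_{\text{data}} - \langle s_j\rangle_{\text{GBS}}`. The helper names are
hypothetical, and the target means are produced by the same model at known parameters as a
stand-in for a real dataset; threshold detectors would need a different formula for the mean
number of clicks.

```python
import numpy as np


def mean_photons_by_mode(B: np.ndarray) -> np.ndarray:
    # <s_k> for a pure GBS state encoding the real symmetric matrix B, assuming
    # photon-number-resolving detectors and eigenvalues of B inside (-1, 1):
    # <s_k> = [B^2 (I - B^2)^{-1}]_kk, which needs only a matrix inversion.
    B2 = B @ B
    return np.diag(B2 @ np.linalg.inv(np.eye(len(B)) - B2)).copy()


def weighted(B: np.ndarray, theta: np.ndarray) -> np.ndarray:
    # A_W = W B W with W = diag(sqrt(w_k)) and w_k = exp(-theta_k)
    W = np.diag(np.exp(-theta / 2))
    return W @ B @ W


def kl_grad(theta: np.ndarray, B: np.ndarray, mean_data: np.ndarray) -> np.ndarray:
    # With d w_k / d theta_j = -w_k delta_jk, the KL gradient reduces to
    # <s_j>_data - <s_j>_GBS
    return mean_data - mean_photons_by_mode(weighted(B, theta))


# Triangle-graph adjacency, scaled so that its eigenvalues (0.6, -0.3, -0.3) lie in (-1, 1)
B = 0.3 * np.array([[0.0, 1.0, 1.0],
                    [1.0, 0.0, 1.0],
                    [1.0, 1.0, 0.0]])
# Stand-in for a real dataset: mean photon numbers of the same model at known parameters
theta_true = np.array([0.3, -0.1, 0.0])
mean_data = mean_photons_by_mode(weighted(B, theta_true))

theta = np.zeros(3)
initial_gap = np.linalg.norm(kl_grad(theta, B, mean_data))
for _ in range(300):
    theta -= 0.5 * kl_grad(theta, B, mean_data)  # plain gradient descent
final_gap = np.linalg.norm(kl_grad(theta, B, mean_data))
print(initial_gap, final_gap)
```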

Stochastic optimization
-----------------------
We're ready to start using Strawberry Fields to train GBS distributions. The main functions needed
can be found in the :mod:`~strawberryfields.apps.train` module, so let's start by importing it.
"""

from strawberryfields.apps import train

##############################################################################
# We explore a basic example where the goal is to optimize the distribution to favour photons
# appearing in a specific subset of modes, while minimizing the number of photons in the
# remaining modes. This can be achieved with the following cost function:

import numpy as np

def h(s):
    # ``subset`` is a global list of target modes; it is defined later, before training starts
    not_subset = [k for k in range(len(s)) if k not in subset]
    # Penalize photons outside the subset and reward photons inside it
    return sum(s[not_subset]) - sum(s[subset])


##############################################################################
# The function is defined with respect to the subset of modes for which we want to
# observe many photons. This subset can be specified later on. Then, for a given sample ``s``,
# we want to *maximize* the total number of photons in the subset. This can be achieved by
# minimizing its negative value, hence the term ``-sum(s[subset])``. Similarly, for modes
# outside of the specified subset, we want to minimize their total sum, which explains the
# appearance of ``sum(s[not_subset])``.
#
# It's now time to define the variational circuit. We'll train a distribution based on
# a simple lollipop 🍭 graph with five nodes:

import networkx as nx
from strawberryfields.apps import plot

graph = nx.lollipop_graph(3, 2)
A = nx.to_numpy_array(graph)
plot.graph(graph)

##############################################################################
# Defining a variational GBS circuit consists of three steps: (i) specifying the embedding,
# (ii) building the circuit, and (iii) defining the cost function with respect to the circuit and
# embedding. We'll go through each step one at a time. For the embedding of trainable parameters,
# we'll use the simple form :math:`w_k = \exp(-\theta_k)` outlined above, which can be accessed
# through the :class:`~strawberryfields.apps.train.Exp` class. Its only input is the number of
# modes in the device, which is equal to the number of nodes in the graph.

nr_modes = len(A)
weights = train.Exp(nr_modes)

##############################################################################
# Easy! The GBS distribution is determined by the symmetric matrix :math:`A`, which we
# train using the WAW parametrization, and by the total mean photon number. There is freedom
# in choosing :math:`A`, but here we'll just use the graph's adjacency matrix. The total
# mean photon number is a hyperparameter of the distribution: in general, different choices may
# lead to different results in training. Note also that the mean photon number may change during
# training as a consequence of the weights being optimized. Finally, GBS devices can operate either
# with photon-number-resolving detectors or threshold detectors, so there is an option to specify
# which one we intend to use. We'll stick to detectors that can count photons.

n_mean = 6
vgbs = train.VGBS(A, n_mean, weights, threshold=False)

##############################################################################
# The last step before training is to define the cost function with respect to our
# previous choices. Since this is a stochastic optimization task, we employ the
# :class:`~strawberryfields.apps.train.Stochastic` class and input our previously defined cost
# function ``h``.

cost = train.Stochastic(h, vgbs)

##############################################################################
# During training, we'll calculate gradients and evaluate the average of this cost function with
# respect to the GBS distribution. Both of these actions require estimating expectation values, so
# the number of samples in the estimation also needs to be specified. The parameters also need to be
# initialized. There is freedom in this choice, but here we'll set them all to zero. Finally, we'll
# aim to increase the number of photons in the "candy" part of the lollipop graph,
# which corresponds to the subset of modes ``[0, 1, 2]``.

np.random.seed(1969)  # for reproducibility
d = nr_modes
params = np.zeros(d)
subset = [0, 1, 2]
nr_samples = 100

print('Initial mean photon numbers = ', vgbs.mean_photons_by_mode(params))

##############################################################################
# If training is successful, we should see the mean photon numbers of the first three modes
# increasing, while those of the last two modes become close to zero. We perform training over
# 200 steps of gradient descent with a learning rate of 0.01:

nr_steps = 200
rate = 0.01

for i in range(nr_steps):
    params -= rate * cost.grad(params, nr_samples)
    if i % 50 == 0:
        print('Cost = {:.3f}'.format(cost.evaluate(params, nr_samples)))

print('Final mean photon numbers = ', vgbs.mean_photons_by_mode(params))

##############################################################################
# Great! The cost function decreases smoothly and there is a clear increase in the mean photon
# numbers of the target modes, with a corresponding decrease in the remaining modes.
#
# The transformed matrix :math:`A_W = W A W` must also have eigenvalues bounded between
# :math:`-1` and :math:`1`, so training indefinitely can lead to unphysical distributions when
# the weights become too large. It's important to monitor this behaviour.
#
# We can confirm that the trained model is behaving according to plan by generating some samples.
# Although we still observe a couple of photons in the last two modes, most of the detections
# happen in the first three modes that we are targeting, just as intended.

Aw = vgbs.A(params)
samples = vgbs.generate_samples(Aw, n_samples=10)
print(samples)

##############################################################################
# Unsupervised learning
# ---------------------
# We are going to train a circuit based on one of the pre-generated datasets in
# the :mod:`~strawberryfields.apps.data` module, namely
# :class:`~strawberryfields.apps.data.Planted`. Gradients in this setting can be calculated
# efficiently, so we can study a larger graph with 30 nodes.
#

from strawberryfields.apps import data

data_pl = data.Planted()
data_samples = data_pl[:1000]  # we use only the first one thousand samples
A = data_pl.adj
nr_modes = len(A)

##############################################################################
# As before, we use the exponential embedding, but this time we define the VGBS circuit in
# threshold mode, since this is how the data samples were generated. The cost function is the
# Kullback-Leibler divergence, which depends on the data samples and can be accessed
# using :class:`~strawberryfields.apps.train.KL`:

weights = train.Exp(nr_modes)
n_mean = 1
vgbs = train.VGBS(A, n_mean, weights, threshold=True)
cost = train.KL(data_samples, vgbs)

##############################################################################
# We initialize the parameters to zero and perform a longer optimization over one thousand
# steps with a learning rate of 0.15, which lets us reach a highly trained model.
# Gradients can be computed efficiently, but evaluating the cost function is challenging
# because it requires calculating GBS probabilities, which in general takes exponential time.
# Instead, we'll keep track of the differences in the mean number of clicks per mode for the data
# and model distributions.

from numpy.linalg import norm

params = np.zeros(nr_modes)
steps = 1000
rate = 0.15

for i in range(steps):
    params -= rate * cost.grad(params)
    if i % 100 == 0:
        diff = cost.mean_n_data - vgbs.mean_clicks_by_mode(params)
        print('Norm of difference = {:.5f}'.format(norm(diff)))

##############################################################################
# Wow! WAW! We reach almost perfect agreement between the data and the trained model. We can also
# generate a few samples, this time creating click patterns that, in general, do not appear in
# the original training data.

Aw = vgbs.A(params)
samples = vgbs.generate_samples(Aw, n_samples=10)
print(samples)

##############################################################################
# The examples we have covered are introductory tasks aimed at mastering the basics of training
# variational GBS circuits. These ideas are new and there is much left to explore in terms of
# their scope and extensions. For example, the original paper [[#banchi2020training]_] studies how
# GBS devices can be trained to find solutions to max clique problems. What new applications come
# to your mind?
#
# References
# ----------
#
# .. [#banchi2020training]
#
#     Leonardo Banchi, Nicolás Quesada, and Juan Miguel Arrazola. Training Gaussian Boson
#     Sampling Distributions. arXiv:2004.04770. 2020.
