import logging
from functools import partial

import numpy as np
from scipy.sparse import csr_matrix

from openTSNE import kl_divergence
from openTSNE.tsne import TSNEEmbedding

log = logging.getLogger(__name__)

class Callback:
    def optimization_about_to_start(self):
        """This is called at the beginning of the optimization procedure."""
    def __call__(self, iteration, error, embedding):
        """This is the main method called from the optimization.

        Parameters
        ----------
        iteration: int
            The current iteration number.

        error: float
            The current KL divergence of the given embedding.

        embedding: TSNEEmbedding
            The current t-SNE embedding.

        Returns
        -------
        stop_optimization: bool
            If this value is set to ``True``, the optimization will be
            interrupted.

        """
class VerifyExaggerationError(Callback):
"""Used to verify that the exaggeration correction implemented in
`gradient_descent` is correct."""
def __init__(self, embedding: TSNEEmbedding) -> None:
self.embedding = embedding
# Keep a copy of the unexaggerated affinity matrix
self.P = self.embedding.affinities.P.copy()
    def __call__(
        self, iteration: int, corrected_error: float, embedding: TSNEEmbedding
    ):
        params = self.embedding.gradient_descent_params
        method = params["negative_gradient_method"]

        # An exaggerated affinity matrix sums to the exaggeration factor,
        # which is greater than 1; an unexaggerated one sums to 1
        if np.sum(embedding.affinities.P) <= 1:
            log.warning("Are you sure you are testing an exaggerated P matrix?")

        if method == "fft":
            f = partial(
                kl_divergence.kl_divergence_approx_fft,
                n_interpolation_points=params["n_interpolation_points"],
                min_num_intervals=params["min_num_intervals"],
                ints_in_interval=params["ints_in_interval"],
                dof=params["dof"],
            )
        elif method == "bh":
            f = partial(
                kl_divergence.kl_divergence_approx_bh,
                theta=params["theta"],
                dof=params["dof"],
            )
        else:
            raise ValueError(
                "Unrecognized negative gradient method `%s`" % method
            )

        P = self.P

        true_error = f(P.indices, P.indptr, P.data, embedding)
        if abs(true_error - corrected_error) > 1e-8:
            raise RuntimeError("Correction term is wrong.")
        else:
            log.info(
                "Corrected: %.4f - True: %.4f [eps %.4f]",
                corrected_error,
                true_error,
                abs(true_error - corrected_error),
            )
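
# --- Background note (not part of the original module). The correction being
# verified above can be derived directly: if the optimizer runs on an
# exaggerated matrix P' = c * P, where P sums to 1, the raw divergence it
# evaluates is sum(c * p * log(c * p / q)) = c * (KL(P||Q) + log(c)). The true
# divergence is therefore recovered as ``raw_error / c - np.log(c)``, which is
# the corrected value this callback checks against a direct evaluation on the
# stored, unexaggerated P.
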
class ErrorApproximations(Callback):
"""Check how good the error approximations are. Of course, we use an
approximation for P so this itself is an approximation."""
def __init__(self, P: csr_matrix):
self.P = P.copy()
self.exact_errors = []
self.bh_errors = []
self.fft_errors = []
    def __call__(self, iteration: int, error: float, embedding: TSNEEmbedding):
        exact_error = kl_divergence.kl_divergence_exact(self.P.toarray(), embedding)
        bh_error = kl_divergence.kl_divergence_approx_bh(
            self.P.indices, self.P.indptr, self.P.data, embedding
        )
        fft_error = kl_divergence.kl_divergence_approx_fft(
            self.P.indices, self.P.indptr, self.P.data, embedding
        )

        self.exact_errors.append(exact_error)
        self.bh_errors.append(bh_error)
        self.fft_errors.append(fft_error)
    def report(self):
        exact_errors = np.array(self.exact_errors)
        bh_errors = np.array(self.bh_errors)
        fft_errors = np.array(self.fft_errors)

        bh_diff = bh_errors - exact_errors
        print(
            "Barnes-Hut: mean difference %.4f (±%.4f)"
            % (np.mean(bh_diff), np.std(bh_diff))
        )

        fft_diff = fft_errors - exact_errors
        print(
            "Interpolation: mean difference %.4f (±%.4f)"
            % (np.mean(fft_diff), np.std(fft_diff))
        )
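
# --- Usage sketch (illustrative, not part of the original module). Tracks all
# three error estimates over a short optimization run and prints the summary.
# It assumes the lower-level ``openTSNE.affinity.PerplexityBasedNN`` and
# ``openTSNE.initialization.pca`` APIs together with
# ``TSNEEmbedding.optimize``; the data set and parameter values are arbitrary.
def _example_error_approximations():
    from openTSNE import affinity, initialization

    x = np.random.RandomState(42).normal(size=(300, 16))
    affinities = affinity.PerplexityBasedNN(x, perplexity=30)
    embedding = TSNEEmbedding(initialization.pca(x), affinities)

    # The callback receives the same (unexaggerated) affinity matrix the
    # embedding was built with
    tracker = ErrorApproximations(affinities.P)
    embedding.optimize(
        n_iter=250, callbacks=[tracker], callbacks_every_iters=25, inplace=True
    )
    tracker.report()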