Skip to content

Commit a7ede57

Browse files
authored
Rename transformer_ to components_ (#230)
1 parent 44fd427 commit a7ede57

19 files changed

+117
-116
lines changed

doc/introduction.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ to the following resources:
135135
arrays and outputs the learned metric score on these two points
136136
.. :math:`M = L^{\top}L` such that distance between vectors ``x`` and
137137
.. ``y`` can be computed as :math:`\sqrt{\left(x-y\right)M\left(x-y\right)}`.
138-
.. - ``transformer_from_metric(metric)``, which returns a transformation matrix
138+
.. - ``components_from_metric(metric)``, which returns a transformation matrix
139139
.. :math:`L \in \mathbb{R}^{D \times d}`, which can be used to convert a
140140
.. data matrix :math:`X \in \mathbb{R}^{n \times d}` to the
141141
.. :math:`D`-dimensional learned metric space :math:`X L^{\top}`,

metric_learn/_util.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def _check_sdp_from_eigen(w, tol=None):
375375
return True
376376

377377

378-
def transformer_from_metric(metric, tol=None):
378+
def components_from_metric(metric, tol=None):
379379
"""Returns the transformation matrix from the Mahalanobis matrix.
380380
381381
Returns the transformation matrix from the Mahalanobis matrix, i.e. the
@@ -429,10 +429,10 @@ def validate_vector(u, dtype=None):
429429
return u
430430

431431

432-
def _initialize_transformer(n_components, input, y=None, init='auto',
433-
verbose=False, random_state=None,
434-
has_classes=True):
435-
"""Returns the initial transformer to be used depending on the arguments.
432+
def _initialize_components(n_components, input, y=None, init='auto',
433+
verbose=False, random_state=None,
434+
has_classes=True):
435+
"""Returns the initial transformation to be used depending on the arguments.
436436
437437
Parameters
438438
----------
@@ -503,8 +503,8 @@ def _initialize_transformer(n_components, input, y=None, init='auto',
503503
504504
Returns
505505
-------
506-
init_transformer : `numpy.ndarray`
507-
The initial transformer to use.
506+
init_components : `numpy.ndarray`
507+
The initial transformation to use.
508508
"""
509509
# if we are doing a regression we cannot use lda:
510510
n_features = input.shape[-1]

metric_learn/base_metric.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ class MahalanobisMixin(six.with_metaclass(ABCMeta, BaseMetricLearner,
177177
178178
Attributes
179179
----------
180-
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
180+
components_ : `numpy.ndarray`, shape=(n_components, n_features)
181181
The learned linear transformation ``L``.
182182
"""
183183

@@ -243,10 +243,10 @@ def transform(self, X):
243243
X_checked = check_input(X, type_of_inputs='classic', estimator=self,
244244
preprocessor=self.preprocessor_,
245245
accept_sparse=True)
246-
return X_checked.dot(self.transformer_.T)
246+
return X_checked.dot(self.components_.T)
247247

248248
def get_metric(self):
249-
transformer_T = self.transformer_.T.copy()
249+
components_T = self.components_.T.copy()
250250

251251
def metric_fun(u, v, squared=False):
252252
"""This function computes the metric between u and v, according to the
@@ -271,7 +271,7 @@ def metric_fun(u, v, squared=False):
271271
"""
272272
u = validate_vector(u)
273273
v = validate_vector(v)
274-
transformed_diff = (u - v).dot(transformer_T)
274+
transformed_diff = (u - v).dot(components_T)
275275
dist = np.dot(transformed_diff, transformed_diff.T)
276276
if not squared:
277277
dist = np.sqrt(dist)
@@ -298,7 +298,7 @@ def get_mahalanobis_matrix(self):
298298
M : `numpy.ndarray`, shape=(n_features, n_features)
299299
The copy of the learned Mahalanobis matrix.
300300
"""
301-
return self.transformer_.T.dot(self.transformer_)
301+
return self.components_.T.dot(self.components_)
302302

303303

304304
class _PairsClassifierMixin(BaseMetricLearner):
@@ -333,7 +333,7 @@ def predict(self, pairs):
333333
y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
334334
The predicted learned metric value between samples in every pair.
335335
"""
336-
check_is_fitted(self, ['threshold_', 'transformer_'])
336+
check_is_fitted(self, ['threshold_', 'components_'])
337337
return 2 * (- self.decision_function(pairs) <= self.threshold_) - 1
338338

339339
def decision_function(self, pairs):
@@ -599,7 +599,7 @@ def predict(self, quadruplets):
599599
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
600600
Predictions of the ordering of pairs, for each quadruplet.
601601
"""
602-
check_is_fitted(self, 'transformer_')
602+
check_is_fitted(self, 'components_')
603603
quadruplets = check_input(quadruplets, type_of_inputs='tuples',
604604
preprocessor=self.preprocessor_,
605605
estimator=self, tuple_size=self._tuple_size)

metric_learn/covariance.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sklearn.base import TransformerMixin
99

1010
from .base_metric import MahalanobisMixin
11-
from ._util import transformer_from_metric
11+
from ._util import components_from_metric
1212

1313

1414
class Covariance(MahalanobisMixin, TransformerMixin):
@@ -24,9 +24,9 @@ class Covariance(MahalanobisMixin, TransformerMixin):
2424
2525
Attributes
2626
----------
27-
transformer_ : `numpy.ndarray`, shape=(n_features, n_features)
27+
components_ : `numpy.ndarray`, shape=(n_features, n_features)
2828
The linear transformation ``L`` deduced from the learned Mahalanobis
29-
metric (See function `transformer_from_metric`.)
29+
metric (See function `components_from_metric`.)
3030
3131
Examples
3232
--------
@@ -53,5 +53,5 @@ def fit(self, X, y=None):
5353
else:
5454
M = scipy.linalg.pinvh(M)
5555

56-
self.transformer_ = transformer_from_metric(np.atleast_2d(M))
56+
self.components_ = components_from_metric(np.atleast_2d(M))
5757
return self

metric_learn/itml.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sklearn.base import TransformerMixin
1313
from .base_metric import _PairsClassifierMixin, MahalanobisMixin
1414
from .constraints import Constraints, wrap_pairs
15-
from ._util import transformer_from_metric, _initialize_metric_mahalanobis
15+
from ._util import components_from_metric, _initialize_metric_mahalanobis
1616

1717

1818
class _BaseITML(MahalanobisMixin):
@@ -105,7 +105,7 @@ def _fit(self, pairs, y, bounds=None):
105105
print('itml converged at iter: %d, conv = %f' % (it, conv))
106106
self.n_iter_ = it
107107

108-
self.transformer_ = transformer_from_metric(A)
108+
self.components_ = components_from_metric(A)
109109
return self
110110

111111

@@ -186,9 +186,9 @@ class ITML(_BaseITML, _PairsClassifierMixin):
186186
n_iter_ : `int`
187187
The number of iterations the solver has run.
188188
189-
transformer_ : `numpy.ndarray`, shape=(n_features, n_features)
189+
components_ : `numpy.ndarray`, shape=(n_features, n_features)
190190
The linear transformation ``L`` deduced from the learned Mahalanobis
191-
metric (See function `transformer_from_metric`.)
191+
metric (See function `components_from_metric`.)
192192
193193
threshold_ : `float`
194194
If the distance metric between two points is lower than this threshold,
@@ -329,9 +329,9 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
329329
n_iter_ : `int`
330330
The number of iterations the solver has run.
331331
332-
transformer_ : `numpy.ndarray`, shape=(n_features, n_features)
332+
components_ : `numpy.ndarray`, shape=(n_features, n_features)
333333
The linear transformation ``L`` deduced from the learned Mahalanobis
334-
metric (See function `transformer_from_metric`.)
334+
metric (See function `components_from_metric`.)
335335
336336
See Also
337337
--------

metric_learn/lfda.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class LFDA(MahalanobisMixin, TransformerMixin):
5151
5252
Attributes
5353
----------
54-
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
54+
components_ : `numpy.ndarray`, shape=(n_components, n_features)
5555
The learned linear transformation ``L``.
5656
5757
Examples
@@ -155,7 +155,7 @@ def fit(self, X, y):
155155
elif self.embedding_type == 'orthonormalized':
156156
vecs, _ = np.linalg.qr(vecs)
157157

158-
self.transformer_ = vecs.T
158+
self.components_ = vecs.T
159159
return self
160160

161161

metric_learn/lmnn.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sklearn.metrics import euclidean_distances
1313
from sklearn.base import TransformerMixin
1414

15-
from ._util import _initialize_transformer, _check_n_components
15+
from ._util import _initialize_components, _check_n_components
1616
from .base_metric import MahalanobisMixin
1717

1818

@@ -117,7 +117,7 @@ class LMNN(MahalanobisMixin, TransformerMixin):
117117
n_iter_ : `int`
118118
The number of iterations the solver has run.
119119
120-
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
120+
components_ : `numpy.ndarray`, shape=(n_components, n_features)
121121
The learned linear transformation ``L``.
122122
123123
Examples
@@ -199,9 +199,9 @@ def fit(self, X, y):
199199
init = 'auto'
200200
else:
201201
init = self.init
202-
self.transformer_ = _initialize_transformer(output_dim, X, y, init,
203-
self.verbose,
204-
self.random_state)
202+
self.components_ = _initialize_components(output_dim, X, y, init,
203+
self.verbose,
204+
self.random_state)
205205
required_k = np.bincount(label_inds).min()
206206
if self.k > required_k:
207207
raise ValueError('not enough class labels for specified k'
@@ -226,7 +226,7 @@ def fit(self, X, y):
226226
a2[nn_idx] = np.array([])
227227

228228
# initialize L
229-
L = self.transformer_
229+
L = self.components_
230230

231231
# first iteration: we compute variables (including objective and gradient)
232232
# at initialization point
@@ -281,7 +281,7 @@ def fit(self, X, y):
281281
print("LMNN didn't converge in %d steps." % self.max_iter)
282282

283283
# store the last L
284-
self.transformer_ = L
284+
self.components_ = L
285285
self.n_iter_ = it
286286
return self
287287

metric_learn/lsml.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from .base_metric import _QuadrupletsClassifierMixin, MahalanobisMixin
1414
from .constraints import Constraints
15-
from ._util import transformer_from_metric, _initialize_metric_mahalanobis
15+
from ._util import components_from_metric, _initialize_metric_mahalanobis
1616

1717

1818
class _BaseLSML(MahalanobisMixin):
@@ -94,7 +94,7 @@ def _fit(self, quadruplets, weights=None):
9494
print("Didn't converge after", it, "iterations. Final loss:", s_best)
9595
self.n_iter_ = it
9696

97-
self.transformer_ = transformer_from_metric(M)
97+
self.components_ = components_from_metric(M)
9898
return self
9999

100100
def _comparison_loss(self, metric, vab, vcd):
@@ -180,9 +180,9 @@ class LSML(_BaseLSML, _QuadrupletsClassifierMixin):
180180
n_iter_ : `int`
181181
The number of iterations the solver has run.
182182
183-
transformer_ : `numpy.ndarray`, shape=(n_features, n_features)
183+
components_ : `numpy.ndarray`, shape=(n_features, n_features)
184184
The linear transformation ``L`` deduced from the learned Mahalanobis
185-
metric (See function `transformer_from_metric`.)
185+
metric (See function `components_from_metric`.)
186186
187187
Examples
188188
--------
@@ -294,9 +294,9 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):
294294
n_iter_ : `int`
295295
The number of iterations the solver has run.
296296
297-
transformer_ : `numpy.ndarray`, shape=(n_features, n_features)
297+
components_ : `numpy.ndarray`, shape=(n_features, n_features)
298298
The linear transformation ``L`` deduced from the learned Mahalanobis
299-
metric (See function `transformer_from_metric`.)
299+
metric (See function `components_from_metric`.)
300300
"""
301301

302302
def __init__(self, tol=1e-3, max_iter=1000, prior=None,

metric_learn/mlkr.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from metric_learn._util import _check_n_components
1717
from .base_metric import MahalanobisMixin
18-
from ._util import _initialize_transformer
18+
from ._util import _initialize_components
1919

2020
EPS = np.finfo(float).eps
2121

@@ -103,7 +103,7 @@ class MLKR(MahalanobisMixin, TransformerMixin):
103103
n_iter_ : `int`
104104
The number of iterations the solver has run.
105105
106-
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
106+
components_ : `numpy.ndarray`, shape=(n_components, n_features)
107107
The learned linear transformation ``L``.
108108
109109
Examples
@@ -182,10 +182,10 @@ def fit(self, X, y):
182182
init = 'auto'
183183
else:
184184
init = self.init
185-
A = _initialize_transformer(m, X, y, init=init,
186-
random_state=self.random_state,
187-
# MLKR works on regression targets:
188-
has_classes=False)
185+
A = _initialize_components(m, X, y, init=init,
186+
random_state=self.random_state,
187+
# MLKR works on regression targets:
188+
has_classes=False)
189189

190190
# Measure the total training time
191191
train_time = time.time()
@@ -194,7 +194,7 @@ def fit(self, X, y):
194194
res = minimize(self._loss, A.ravel(), (X, y), method='L-BFGS-B',
195195
jac=True, tol=self.tol,
196196
options=dict(maxiter=self.max_iter))
197-
self.transformer_ = res.x.reshape(A.shape)
197+
self.components_ = res.x.reshape(A.shape)
198198

199199
# Stop timer
200200
train_time = time.time() - train_time

metric_learn/mmc.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from .base_metric import _PairsClassifierMixin, MahalanobisMixin
1111
from .constraints import Constraints, wrap_pairs
12-
from ._util import transformer_from_metric, _initialize_metric_mahalanobis
12+
from ._util import components_from_metric, _initialize_metric_mahalanobis
1313

1414

1515
class _BaseMMC(MahalanobisMixin):
@@ -185,7 +185,7 @@ def _fit_full(self, pairs, y):
185185
self.A_[:] = A_old
186186
self.n_iter_ = cycle
187187

188-
self.transformer_ = transformer_from_metric(self.A_)
188+
self.components_ = components_from_metric(self.A_)
189189
return self
190190

191191
def _fit_diag(self, pairs, y):
@@ -246,7 +246,7 @@ def _fit_diag(self, pairs, y):
246246

247247
self.A_ = np.diag(w)
248248

249-
self.transformer_ = transformer_from_metric(self.A_)
249+
self.components_ = components_from_metric(self.A_)
250250
return self
251251

252252
def _fD(self, neg_pairs, A):
@@ -409,9 +409,9 @@ class MMC(_BaseMMC, _PairsClassifierMixin):
409409
n_iter_ : `int`
410410
The number of iterations the solver has run.
411411
412-
transformer_ : `numpy.ndarray`, shape=(n_features, n_features)
412+
components_ : `numpy.ndarray`, shape=(n_features, n_features)
413413
The linear transformation ``L`` deduced from the learned Mahalanobis
414-
metric (See function `transformer_from_metric`.)
414+
metric (See function `components_from_metric`.)
415415
416416
threshold_ : `float`
417417
If the distance metric between two points is lower than this threshold,
@@ -550,9 +550,9 @@ class MMC_Supervised(_BaseMMC, TransformerMixin):
550550
n_iter_ : `int`
551551
The number of iterations the solver has run.
552552
553-
transformer_ : `numpy.ndarray`, shape=(n_features, n_features)
553+
components_ : `numpy.ndarray`, shape=(n_features, n_features)
554554
The linear transformation ``L`` deduced from the learned Mahalanobis
555-
metric (See function `transformer_from_metric`.)
555+
metric (See function `components_from_metric`.)
556556
"""
557557

558558
def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,

0 commit comments

Comments
 (0)