Coverage for colour/recovery/otsu2018.py: 100%

321 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-15 19:01 +1300

1""" 

2Otsu, Yamamoto and Hachisuka (2018) - Reflectance Recovery 

3========================================================== 

4 

5Define the objects for reflectance recovery, i.e., spectral upsampling, using 

6*Otsu et al. (2018)* method. 

7 

8- :class:`colour.recovery.Dataset_Otsu2018` 

9- :func:`colour.recovery.XYZ_to_sd_Otsu2018` 

10- :func:`colour.recovery.Tree_Otsu2018` 

11 

12References 

13---------- 

14- :cite:`Otsu2018` : Otsu, H., Yamamoto, M., & Hachisuka, T. (2018). 

15 Reproducing Spectral Reflectances From Tristimulus Colours. Computer 

16 Graphics Forum, 37(6), 370-381. doi:10.1111/cgf.13332 

17""" 

18 

19from __future__ import annotations 

20 

21import typing 

22from dataclasses import dataclass 

23 

24import numpy as np 

25 

26from colour.algebra import eigen_decomposition 

27from colour.colorimetry import ( 

28 MultiSpectralDistributions, 

29 SpectralDistribution, 

30 SpectralShape, 

31 handle_spectral_arguments, 

32 msds_to_XYZ_integration, 

33 reshape_msds, 

34 sd_to_XYZ, 

35) 

36 

37if typing.TYPE_CHECKING: 

38 from colour.hints import ( 

39 Any, 

40 ArrayLike, 

41 Callable, 

42 Dict, 

43 Domain1, 

44 PathLike, 

45 Self, 

46 Sequence, 

47 Tuple, 

48 ) 

49 

50from colour.hints import NDArrayFloat, cast 

51from colour.models import XYZ_to_xy 

52from colour.recovery import ( 

53 BASIS_FUNCTIONS_OTSU2018, 

54 CLUSTER_MEANS_OTSU2018, 

55 SELECTOR_ARRAY_OTSU2018, 

56 SPECTRAL_SHAPE_OTSU2018, 

57) 

58from colour.utilities import ( 

59 TreeNode, 

60 as_float_array, 

61 as_float_scalar, 

62 domain_range_scale, 

63 is_tqdm_installed, 

64 message_box, 

65 optional, 

66 to_domain_1, 

67 zeros, 

68) 

69 

70if is_tqdm_installed(): 

71 from tqdm import tqdm 

72else: # pragma: no cover 

73 from unittest import mock 

74 

75 tqdm = mock.MagicMock() 

76 

77__author__ = "Colour Developers" 

78__copyright__ = "Copyright 2013 Colour Developers" 

79__license__ = "BSD-3-Clause - https://opensource.org/licenses/BSD-3-Clause" 

80__maintainer__ = "Colour Developers" 

81__email__ = "colour-developers@colour-science.org" 

82__status__ = "Production" 

83 

84__all__ = [ 

85 "Dataset_Otsu2018", 

86 "DATASET_REFERENCE_OTSU2018", 

87 "XYZ_to_sd_Otsu2018", 

88 "PartitionAxis", 

89 "Data_Otsu2018", 

90 "Node_Otsu2018", 

91 "Tree_Otsu2018", 

92] 

93 

94 

class Dataset_Otsu2018:
    """
    Store all information required for the *Otsu et al. (2018)* spectral
    upsampling method.

    Datasets can be generated and converted as a
    :class:`colour.recovery.Dataset_Otsu2018` class instance using the
    :meth:`colour.recovery.Tree_Otsu2018.to_dataset` method or loaded from
    disk with the :meth:`colour.recovery.Dataset_Otsu2018.read` method.

    Parameters
    ----------
    shape
        Shape of the spectral data.
    basis_functions
        Three basis functions for every cluster.
    means
        Mean for every cluster.
    selector_array
        Array describing how to select the appropriate cluster. See the
        :meth:`colour.recovery.Dataset_Otsu2018.select` method for details.

    Attributes
    ----------
    -   :attr:`~colour.recovery.Dataset_Otsu2018.shape`
    -   :attr:`~colour.recovery.Dataset_Otsu2018.basis_functions`
    -   :attr:`~colour.recovery.Dataset_Otsu2018.means`
    -   :attr:`~colour.recovery.Dataset_Otsu2018.selector_array`

    Methods
    -------
    -   :meth:`~colour.recovery.Dataset_Otsu2018.__init__`
    -   :meth:`~colour.recovery.Dataset_Otsu2018.select`
    -   :meth:`~colour.recovery.Dataset_Otsu2018.cluster`
    -   :meth:`~colour.recovery.Dataset_Otsu2018.read`
    -   :meth:`~colour.recovery.Dataset_Otsu2018.write`

    References
    ----------
    :cite:`Otsu2018`

    Examples
    --------
    >>> import os
    >>> import colour
    >>> from colour.characterisation import SDS_COLOURCHECKERS
    >>> from colour.colorimetry import sds_and_msds_to_msds
    >>> reflectances = sds_and_msds_to_msds(
    ...     SDS_COLOURCHECKERS["ColorChecker N Ohta"].values()
    ... )
    >>> node_tree = Tree_Otsu2018(reflectances)
    >>> node_tree.optimise(iterations=2, print_callable=lambda x: x)
    >>> dataset = node_tree.to_dataset()
    >>> path = os.path.join(
    ...     colour.__path__[0],
    ...     "recovery",
    ...     "tests",
    ...     "resources",
    ...     "ColorChecker_Otsu2018.npz",
    ... )
    >>> dataset.write(path)  # doctest: +SKIP
    >>> dataset = Dataset_Otsu2018()  # doctest: +SKIP
    >>> dataset.read(path)  # doctest: +SKIP
    """

    def __init__(
        self,
        shape: SpectralShape | None = None,
        basis_functions: NDArrayFloat | None = None,
        means: NDArrayFloat | None = None,
        selector_array: NDArrayFloat | None = None,
    ) -> None:
        self._shape: SpectralShape | None = shape
        self._basis_functions: NDArrayFloat | None = (
            basis_functions
            if basis_functions is None
            else as_float_array(basis_functions)
        )
        self._means: NDArrayFloat | None = (
            means if means is None else as_float_array(means)
        )
        self._selector_array: NDArrayFloat | None = (
            selector_array if selector_array is None else as_float_array(selector_array)
        )

    @property
    def shape(self) -> SpectralShape | None:
        """
        Getter for the spectral shape of the *Otsu et al. (2018)* dataset.

        Returns
        -------
        :class:`colour.SpectralShape` or :py:data:`None`
            Spectral shape used by the *Otsu et al. (2018)* dataset.
        """

        return self._shape

    @property
    def basis_functions(self) -> NDArrayFloat | None:
        """
        Getter for the basis functions of the *Otsu et al. (2018)* dataset.

        Returns
        -------
        :class:`numpy.ndarray` or :py:data:`None`
            Basis functions of the *Otsu et al. (2018)* dataset.
        """

        return self._basis_functions

    @property
    def means(self) -> NDArrayFloat | None:
        """
        Getter for the means of the *Otsu et al. (2018)* dataset.

        Returns
        -------
        :class:`numpy.ndarray` or :py:data:`None`
            Means of the *Otsu et al. (2018)* dataset.
        """

        return self._means

    @property
    def selector_array(self) -> NDArrayFloat | None:
        """
        Getter for the selector array of the *Otsu et al. (2018)* dataset.

        Returns
        -------
        :class:`numpy.ndarray` or :py:data:`None`
            Selector array of the *Otsu et al. (2018)* dataset.
        """

        return self._selector_array

    def __str__(self) -> str:
        """
        Return a formatted string representation of the dataset.

        Returns
        -------
        :class:`str`
            Formatted string representation.
        """

        if self._basis_functions is not None:
            return (
                f"{self.__class__.__name__}"
                f"({self._basis_functions.shape[0]} basis functions)"
            )

        return f"{self.__class__.__name__}()"

    def select(self, xy: ArrayLike) -> int:
        """
        Select the cluster index for the specified *CIE xy* chromaticity
        coordinates.

        Each row of the selector array is
        ``(origin, direction, lesser_index, greater_index)``. Traversal starts
        at row 0: the coordinate ``xy[direction]`` is compared with
        ``origin`` and ``lesser_index`` (when less than or equal) or
        ``greater_index`` is followed. A negative index ``-j`` continues the
        traversal at row ``j`` whereas a non-negative index is returned as
        the cluster index.

        Parameters
        ----------
        xy
            *CIE xy* chromaticity coordinates.

        Returns
        -------
        :class:`int`
            Cluster index.

        Raises
        ------
        ValueError
            If the selector array is undefined.
        """

        xy = as_float_array(xy)

        if self._selector_array is not None:
            i = 0
            while True:
                row = self._selector_array[i, :]
                origin, direction, lesser_index, greater_index = row

                if xy[int(direction)] <= origin:
                    index = int(lesser_index)
                else:
                    index = int(greater_index)

                if index < 0:
                    i = -index
                else:
                    return index
        else:
            error = 'The "selector array" is undefined!'

            raise ValueError(error)

    def cluster(self, xy: ArrayLike) -> Tuple[NDArrayFloat, NDArrayFloat]:
        """
        Retrieve the basis functions and dataset mean for the specified
        *CIE xy* chromaticity coordinates.

        Parameters
        ----------
        xy
            *CIE xy* chromaticity coordinates.

        Returns
        -------
        :class:`tuple`
            Tuple of three basis functions and dataset mean.

        Raises
        ------
        ValueError
            If the basis functions or means are undefined.
        """

        if self._basis_functions is not None and self._means is not None:
            index = self.select(xy)

            return self._basis_functions[index, :, :], self._means[index, :]

        error = 'The "basis functions" or "means" are undefined!'

        raise ValueError(error)

    def read(self, path: str | PathLike) -> None:
        """
        Read and load a dataset from an *.npz* file.

        Parameters
        ----------
        path
            File path for reading the dataset.

        Examples
        --------
        >>> import os
        >>> import colour
        >>> from colour.characterisation import SDS_COLOURCHECKERS
        >>> from colour.colorimetry import sds_and_msds_to_msds
        >>> reflectances = sds_and_msds_to_msds(
        ...     SDS_COLOURCHECKERS["ColorChecker N Ohta"].values()
        ... )
        >>> node_tree = Tree_Otsu2018(reflectances)
        >>> node_tree.optimise(iterations=2, print_callable=lambda x: x)
        >>> dataset = node_tree.to_dataset()
        >>> path = os.path.join(
        ...     colour.__path__[0],
        ...     "recovery",
        ...     "tests",
        ...     "resources",
        ...     "ColorChecker_Otsu2018.npz",
        ... )
        >>> dataset.write(path)  # doctest: +SKIP
        >>> dataset = Dataset_Otsu2018()  # doctest: +SKIP
        >>> dataset.read(path)  # doctest: +SKIP
        """

        path = str(path)

        # "np.load" returns a lazily-reading "NpzFile" that keeps the
        # underlying file open: using it as a context manager releases the
        # handle once every array has been read into memory.
        with np.load(path) as data:
            start, end, interval = data["shape"]
            self._shape = SpectralShape(start, end, interval)
            self._basis_functions = data["basis_functions"]
            self._means = data["means"]
            self._selector_array = data["selector_array"]

    def write(self, path: str | PathLike) -> None:
        """
        Write the dataset to an *.npz* file at the specified path.

        Parameters
        ----------
        path
            Path to the file.

        Raises
        ------
        ValueError
            If the shape, the basis functions, the means or the selector
            array are undefined.

        Examples
        --------
        >>> import os
        >>> import colour
        >>> from colour.characterisation import SDS_COLOURCHECKERS
        >>> from colour.colorimetry import sds_and_msds_to_msds
        >>> reflectances = sds_and_msds_to_msds(
        ...     SDS_COLOURCHECKERS["ColorChecker N Ohta"].values()
        ... )
        >>> node_tree = Tree_Otsu2018(reflectances)
        >>> node_tree.optimise(iterations=2, print_callable=lambda x: x)
        >>> dataset = node_tree.to_dataset()
        >>> path = os.path.join(
        ...     colour.__path__[0],
        ...     "recovery",
        ...     "tests",
        ...     "resources",
        ...     "ColorChecker_Otsu2018.npz",
        ... )
        >>> dataset.write(path)  # doctest: +SKIP
        """

        path = str(path)

        if self._shape is None:
            error = 'The "shape" is undefined!'

            raise ValueError(error)

        # Undefined arrays would otherwise be saved as pickled object arrays
        # that "read" cannot load back with "np.load" defaults.
        if (
            self._basis_functions is None
            or self._means is None
            or self._selector_array is None
        ):
            error = (
                'The "basis functions", "means" or "selector array" are '
                "undefined!"
            )

            raise ValueError(error)

        np.savez(
            path,
            shape=as_float_array(
                [
                    self._shape.start,
                    self._shape.end,
                    self._shape.interval,
                ]
            ),
            basis_functions=self._basis_functions,
            means=self._means,
            selector_array=self._selector_array,
        )

422 

423 

DATASET_REFERENCE_OTSU2018: Dataset_Otsu2018 = Dataset_Otsu2018(
    shape=SPECTRAL_SHAPE_OTSU2018,
    basis_functions=BASIS_FUNCTIONS_OTSU2018,
    means=CLUSTER_MEANS_OTSU2018,
    selector_array=SELECTOR_ARRAY_OTSU2018,
)
"""
Reference *Otsu et al. (2018)* dataset, assembled from the builtin spectral
shape, basis functions, cluster means and selector array as a
:class:`colour.recovery.Dataset_Otsu2018` class instance. It is the default
dataset of the :func:`colour.recovery.XYZ_to_sd_Otsu2018` definition and can
be used by other definitions as well.
"""

435 

436 

def XYZ_to_sd_Otsu2018(
    XYZ: Domain1,
    cmfs: MultiSpectralDistributions | None = None,
    illuminant: SpectralDistribution | None = None,
    dataset: Dataset_Otsu2018 = DATASET_REFERENCE_OTSU2018,
    clip: bool = True,
) -> SpectralDistribution:
    """
    Recover the spectral distribution of the specified *CIE XYZ* tristimulus
    values using *Otsu et al. (2018)* method.

    Parameters
    ----------
    XYZ
        *CIE XYZ* tristimulus values to recover the spectral distribution
        from.
    cmfs
        Standard observer colour matching functions, default to the
        *CIE 1931 2 Degree Standard Observer*. When omitted, they are built
        with the dataset spectral shape as default shape.
    illuminant
        Illuminant spectral distribution, default to
        *CIE Standard Illuminant D65*.
    dataset
        Dataset to use for reconstruction. The default is to use the
        published data.
    clip
        If *True*, the default, values below zero and above unity in the
        recovered spectral distributions will be clipped. This ensures that
        the returned reflectance is physical and conserves energy, but will
        cause noticeable colour differences in case of very saturated
        colours.

    Returns
    -------
    :class:`colour.SpectralDistribution`
        Recovered spectral distribution. Its wavelengths are those of the
        dataset spectral shape, i.e.,
        :attr:`colour.recovery.Dataset_Otsu2018.shape`, which for the default
        dataset is :attr:`colour.recovery.SPECTRAL_SHAPE_OTSU2018`.

    Raises
    ------
    ValueError
        If the dataset shape is undefined.

    References
    ----------
    :cite:`Otsu2018`

    Notes
    -----
    +------------+-----------------------+---------------+
    | **Domain** | **Scale - Reference** | **Scale - 1** |
    +============+=======================+===============+
    | ``XYZ``    | 1                     | 1             |
    +------------+-----------------------+---------------+

    Examples
    --------
    >>> from colour import (
    ...     CCS_ILLUMINANTS,
    ...     SDS_ILLUMINANTS,
    ...     MSDS_CMFS,
    ...     XYZ_to_sRGB,
    ... )
    >>> from colour.colorimetry import sd_to_XYZ_integration
    >>> from colour.utilities import numpy_print_options
    >>> XYZ = np.array([0.20654008, 0.12197225, 0.05136952])
    >>> cmfs = (
    ...     MSDS_CMFS["CIE 1931 2 Degree Standard Observer"]
    ...     .copy()
    ...     .align(SPECTRAL_SHAPE_OTSU2018)
    ... )
    >>> illuminant = SDS_ILLUMINANTS["D65"].copy().align(cmfs.shape)
    >>> sd = XYZ_to_sd_Otsu2018(XYZ, cmfs, illuminant)
    >>> with numpy_print_options(suppress=True):
    ...     sd  # doctest: +ELLIPSIS
    SpectralDistribution([[ 380.        ,    0.0601939...],
                          [ 390.        ,    0.0568063...],
                          [ 400.        ,    0.0517429...],
                          [ 410.        ,    0.0495841...],
                          [ 420.        ,    0.0502007...],
                          [ 430.        ,    0.0506489...],
                          [ 440.        ,    0.0510020...],
                          [ 450.        ,    0.0493782...],
                          [ 460.        ,    0.0468046...],
                          [ 470.        ,    0.0437132...],
                          [ 480.        ,    0.0416957...],
                          [ 490.        ,    0.0403783...],
                          [ 500.        ,    0.0405197...],
                          [ 510.        ,    0.0406031...],
                          [ 520.        ,    0.0416912...],
                          [ 530.        ,    0.0430956...],
                          [ 540.        ,    0.0444474...],
                          [ 550.        ,    0.0459336...],
                          [ 560.        ,    0.0507631...],
                          [ 570.        ,    0.0628967...],
                          [ 580.        ,    0.0844661...],
                          [ 590.        ,    0.1334277...],
                          [ 600.        ,    0.2262428...],
                          [ 610.        ,    0.3599330...],
                          [ 620.        ,    0.4885571...],
                          [ 630.        ,    0.5752546...],
                          [ 640.        ,    0.6193023...],
                          [ 650.        ,    0.6450744...],
                          [ 660.        ,    0.6610548...],
                          [ 670.        ,    0.6688673...],
                          [ 680.        ,    0.6795426...],
                          [ 690.        ,    0.6887933...],
                          [ 700.        ,    0.7003469...],
                          [ 710.        ,    0.7084128...],
                          [ 720.        ,    0.7154674...],
                          [ 730.        ,    0.7234334...]],
                         SpragueInterpolator,
                         {},
                         Extrapolator,
                         {'method': 'Constant', 'left': None, 'right': None})
    >>> sd_to_XYZ_integration(sd, cmfs, illuminant) / 100  # doctest: +ELLIPSIS
    array([ 0.2065494...,  0.1219712...,  0.0514002...])
    """

    shape = dataset.shape

    if shape is None:
        error = 'The dataset "shape" is undefined!'

        raise ValueError(error)

    XYZ = to_domain_1(XYZ)

    # Default colour matching functions and illuminant are built at the
    # dataset spectral shape so that they match its basis functions and means.
    cmfs, illuminant = handle_spectral_arguments(
        cmfs, illuminant, shape_default=shape
    )

    xy = XYZ_to_xy(XYZ)

    basis_functions, mean = dataset.cluster(xy)

    # Column "i" holds the tristimulus values of the "i"-th basis function,
    # so that "M" maps basis function weights to *CIE XYZ*.
    M = np.empty((3, 3))
    for i in range(3):
        sd = SpectralDistribution(basis_functions[i, :], shape.wavelengths)

        with domain_range_scale("ignore"):
            M[:, i] = sd_to_XYZ(sd, cmfs, illuminant) / 100

    M_inverse = np.linalg.inv(M)

    sd = SpectralDistribution(mean, shape.wavelengths)

    with domain_range_scale("ignore"):
        XYZ_mu = sd_to_XYZ(sd, cmfs, illuminant) / 100

    # Solve for the weights reproducing "XYZ" relative to the cluster mean.
    weights = np.dot(M_inverse, XYZ - XYZ_mu)
    recovered_sd = np.dot(weights, basis_functions) + mean

    recovered_sd = np.clip(recovered_sd, 0, 1) if clip else recovered_sd

    return SpectralDistribution(recovered_sd, shape.wavelengths)

592 

593 

@dataclass
class PartitionAxis:
    """
    Define an axis-aligned line that splits the 2D plane into two
    half-planes.

    Parameters
    ----------
    origin
        Position of the line: the X-coordinate of a vertical line or the
        Y-coordinate of a horizontal line.
    direction
        Orientation of the line, *0* for vertical and *1* for horizontal.

    Methods
    -------
    -   :meth:`~colour.recovery.otsu2018.PartitionAxis.__str__`
    """

    origin: float
    direction: int

    def __str__(self) -> str:
        """
        Describe the partition axis as a human readable string.

        Returns
        -------
        :class:`str`
            Formatted string representation.
        """

        if self.direction:
            orientation, coordinate = "horizontal", "y"
        else:
            orientation, coordinate = "vertical", "x"

        return (
            f"{self.__class__.__name__}"
            f"({orientation} partition at {coordinate} = {self.origin})"
        )

630 

631 

class Data_Otsu2018:
    """
    Store reference reflectances and derived information, and provide methods
    to process them for a leaf
    :class:`colour.recovery.otsu2018.Node_Otsu2018` class instance.

    Support partitioning by creating two smaller instances of
    :class:`colour.recovery.otsu2018.Data_Otsu2018` through splitting along a
    horizontal or vertical axis on the *CIE xy* chromaticity plane.

    Parameters
    ----------
    reflectances
        Reference reflectances of the *n* colours to be stored.
        The shape must match ``tree.shape`` with *m* points for each colour.
    cmfs
        Standard observer colour matching functions.
    illuminant
        Illuminant spectral distribution.

    Attributes
    ----------
    -   :attr:`~colour.recovery.otsu2018.Data_Otsu2018.reflectances`
    -   :attr:`~colour.recovery.otsu2018.Data_Otsu2018.cmfs`
    -   :attr:`~colour.recovery.otsu2018.Data_Otsu2018.illuminant`
    -   :attr:`~colour.recovery.otsu2018.Data_Otsu2018.basis_functions`
    -   :attr:`~colour.recovery.otsu2018.Data_Otsu2018.mean`

    Methods
    -------
    -   :meth:`~colour.recovery.otsu2018.Data_Otsu2018.__str__`
    -   :meth:`~colour.recovery.otsu2018.Data_Otsu2018.__len__`
    -   :meth:`~colour.recovery.otsu2018.Data_Otsu2018.origin`
    -   :meth:`~colour.recovery.otsu2018.Data_Otsu2018.partition`
    -   :meth:`~colour.recovery.otsu2018.Data_Otsu2018.PCA`
    -   :meth:`~colour.recovery.otsu2018.Data_Otsu2018.reconstruct`
    -   :meth:`~colour.recovery.otsu2018.Data_Otsu2018.reconstruction_error`
    """

    def __init__(
        self,
        reflectances: ArrayLike | None,
        cmfs: MultiSpectralDistributions,
        illuminant: SpectralDistribution,
    ) -> None:
        self._cmfs: MultiSpectralDistributions = cmfs
        self._illuminant: SpectralDistribution = illuminant

        self._XYZ: NDArrayFloat | None = None
        self._xy: NDArrayFloat | None = None

        self._reflectances: NDArrayFloat | None = np.array([])

        # Quantities derived by "PCA" and the cached reconstruction error;
        # the "reflectances" setter resets them whenever it is called.
        self._basis_functions: NDArrayFloat | None = None
        self._mean: NDArrayFloat | None = None
        self._M: NDArrayFloat | None = None
        self._XYZ_mu: NDArrayFloat | None = None

        self._reconstruction_error: float | None = None

        self.reflectances = reflectances

    @property
    def reflectances(self) -> NDArrayFloat | None:
        """
        Getter and setter for the reference reflectances.

        Setting the reflectances recomputes their *CIE XYZ* tristimulus values
        and *CIE xy* chromaticity coordinates and discards any previously
        computed *PCA* results and reconstruction error.

        Parameters
        ----------
        value
            Value to set the reference reflectances with.

        Returns
        -------
        :class:`numpy.ndarray` or :py:data:`None`
            Reference reflectances.
        """

        return self._reflectances

    @reflectances.setter
    def reflectances(self, value: ArrayLike | None) -> None:
        """Setter for the **self.reflectances** property."""

        if value is not None:
            self._reflectances = as_float_array(value)
            self._XYZ = (
                msds_to_XYZ_integration(
                    self._reflectances,
                    self._cmfs,
                    self._illuminant,
                    shape=self._cmfs.shape,
                )
                / 100
            )
            self._xy = XYZ_to_xy(self._XYZ)
        else:
            self._reflectances, self._XYZ, self._xy = None, None, None

        # The PCA results and the reconstruction error are derived from the
        # reflectances: drop them so that they are lazily recomputed instead
        # of silently describing the previous reflectances.
        self._basis_functions = None
        self._mean = None
        self._M = None
        self._XYZ_mu = None
        self._reconstruction_error = None

    @property
    def cmfs(self) -> MultiSpectralDistributions:
        """
        Getter for the standard observer colour matching functions.

        Returns
        -------
        :class:`colour.MultiSpectralDistributions`
            Standard observer colour matching functions.
        """

        return self._cmfs

    @property
    def illuminant(self) -> SpectralDistribution:
        """
        Getter for the illuminant spectral distribution.

        Returns
        -------
        :class:`colour.SpectralDistribution`
            Illuminant spectral distribution.
        """

        return self._illuminant

    @property
    def basis_functions(self) -> NDArrayFloat | None:
        """
        Getter for the basis functions.

        Returns
        -------
        :class:`numpy.ndarray` or :py:data:`None`
            Three basis functions computed by
            :meth:`colour.recovery.otsu2018.Data_Otsu2018.PCA`, or
            :py:data:`None` if *PCA* has not been performed yet.
        """

        return self._basis_functions

    @property
    def mean(self) -> NDArrayFloat | None:
        """
        Getter for the mean of the reference reflectances.

        Returns
        -------
        :class:`numpy.ndarray` or :py:data:`None`
            Mean of the reference reflectances computed by
            :meth:`colour.recovery.otsu2018.Data_Otsu2018.PCA`, or
            :py:data:`None` if *PCA* has not been performed yet.
        """

        return self._mean

    def __str__(self) -> str:
        """
        Return a formatted string representation of the data.

        Returns
        -------
        :class:`str`
            Formatted string representation.
        """

        return f"{self.__class__.__name__}({len(self)} Reflectances)"

    def __len__(self) -> int:
        """
        Return the number of colours in the data.

        Returns
        -------
        :class:`int`
            Number of colours in the data.
        """

        return self._reflectances.shape[0] if self._reflectances is not None else 0

    def origin(self, i: int, direction: int) -> float:
        """
        Retrieve the origin *CIE x* or *CIE y* chromaticity coordinate for
        the specified index and direction.

        Parameters
        ----------
        i
            Origin index.
        direction
            Origin direction, *0* for *CIE x* and *1* for *CIE y*.

        Returns
        -------
        :class:`float`
            Origin *CIE x* or *CIE y* chromaticity coordinate.

        Raises
        ------
        ValueError
            If the chromaticity coordinates are undefined.
        """

        if self._xy is not None:
            return self._xy[i, direction]

        error = 'The "chromaticity coordinates" are undefined!'

        raise ValueError(error)

    def partition(self, axis: PartitionAxis) -> Tuple[Data_Otsu2018, Data_Otsu2018]:
        """
        Partition the data using the specified partition axis.

        Colours whose coordinate along ``axis.direction`` is less than or
        equal to ``axis.origin`` go to the first part, the others to the
        second part.

        Parameters
        ----------
        axis
            Partition axis used to partition the data.

        Returns
        -------
        :class:`tuple`
            Tuple of left or lower part and right or upper part.

        Raises
        ------
        ValueError
            If the tristimulus values or chromaticity coordinates are
            undefined.
        """

        lesser = Data_Otsu2018(None, self._cmfs, self._illuminant)
        greater = Data_Otsu2018(None, self._cmfs, self._illuminant)

        if (
            self._XYZ is not None
            and self._xy is not None
            and self._reflectances is not None
        ):
            mask = self._xy[:, axis.direction] <= axis.origin

            # The private attributes are assigned directly to reuse the
            # already computed tristimulus values and chromaticities.
            lesser._reflectances = self._reflectances[mask, :]
            greater._reflectances = self._reflectances[~mask, :]

            lesser._XYZ = self._XYZ[mask, :]
            greater._XYZ = self._XYZ[~mask, :]

            lesser._xy = self._xy[mask, :]
            greater._xy = self._xy[~mask, :]

            return lesser, greater

        error = 'The "tristimulus values" or "chromaticity coordinates" are undefined!'

        raise ValueError(error)

    def PCA(self) -> None:
        """
        Perform *Principal Component Analysis* (PCA) on the data and set the
        relevant attributes accordingly.

        The mean reflectance, its tristimulus values, the three principal
        basis functions and the matrix :math:`M` mapping basis function
        weights to *CIE XYZ* tristimulus values are computed. Nothing is done
        if they are already available.
        """

        if self._M is None and self._reflectances is not None:
            settings: Dict[str, Any] = {
                "cmfs": self._cmfs,
                "illuminant": self._illuminant,
                "shape": self._cmfs.shape,
            }

            self._mean = np.mean(self._reflectances, axis=0)
            self._XYZ_mu = (
                msds_to_XYZ_integration(cast("NDArrayFloat", self._mean), **settings)
                / 100
            )

            # Eigenvectors are in ascending eigenvalue order: the last three
            # columns are the principal components.
            _w, w = eigen_decomposition(
                self._reflectances - self._mean,  # pyright: ignore
                descending_order=False,
                covariance_matrix=True,
            )
            self._basis_functions = np.transpose(w[:, -3:])

            self._M = np.transpose(
                msds_to_XYZ_integration(self._basis_functions, **settings) / 100
            )

    def reconstruct(self, XYZ: ArrayLike) -> SpectralDistribution:
        """
        Reconstruct the reflectance for the specified *CIE XYZ* tristimulus
        values.

        The reconstructed reflectance is clipped to the [0, 1] range.

        Parameters
        ----------
        XYZ
            *CIE XYZ* tristimulus values to recover the spectral
            distribution from.

        Returns
        -------
        :class:`colour.SpectralDistribution`
            Recovered spectral distribution.

        Raises
        ------
        ValueError
            If the matrix :math:`M`, the mean tristimulus values or the
            basis functions are undefined.
        """

        if (
            self._M is not None
            and self._XYZ_mu is not None
            and self._basis_functions is not None
        ):
            XYZ = as_float_array(XYZ)

            weights = np.dot(np.linalg.inv(self._M), XYZ - self._XYZ_mu)
            reflectance = np.dot(weights, self._basis_functions) + self._mean
            reflectance = np.clip(reflectance, 0, 1)

            return SpectralDistribution(reflectance, self._cmfs.wavelengths)

        error = (
            'The matrix "M", the "mean tristimulus values" or the '
            '"basis functions" are undefined!'
        )

        raise ValueError(error)

    def reconstruction_error(self) -> float:
        """
        Compute the reconstruction error of the data.

        The error is computed by reconstructing the reflectances for the
        reference *CIE XYZ* tristimulus values using PCA and summing the
        squared differences between the reconstructed reflectances and the
        reference reflectances.

        Returns
        -------
        :class:`float`
            Reconstruction error for the data.

        Raises
        ------
        ValueError
            If the tristimulus values are undefined.

        Notes
        -----
        -   The reconstruction error is cached upon being computed and thus
            is only computed once per node, until the reflectances are set
            again.
        """

        if self._reconstruction_error is not None:
            return self._reconstruction_error

        if self._XYZ is not None and self._reflectances is not None:
            self.PCA()

            reconstruction_error: float = 0.0
            for i in range(len(self)):
                sd = self._reflectances[i, :]
                XYZ = self._XYZ[i, :]
                recovered_sd = self.reconstruct(XYZ)
                reconstruction_error += cast(
                    "float", np.sum((sd - recovered_sd.values) ** 2)
                )

            self._reconstruction_error = reconstruction_error

            return reconstruction_error

        error = 'The "tristimulus values" are undefined!'

        raise ValueError(error)

1003 

1004 

1005class Node_Otsu2018(TreeNode): 

1006 """ 

1007 Represent a node in a :meth:`colour.recovery.Tree_Otsu2018` class 

1008 instance node tree. 

1009 

1010 Parameters 

1011 ---------- 

1012 parent 

1013 Parent of the node. 

1014 children 

1015 Children of the node. 

1016 data 

1017 The colour data belonging to this node. 

1018 

1019 Attributes 

1020 ---------- 

1021 - :attr:`~colour.recovery.otsu2018.Node.partition_axis` 

1022 - :attr:`~colour.recovery.otsu2018.Node.row` 

1023 

1024 Methods 

1025 ------- 

1026 - :meth:`~colour.recovery.otsu2018.Node.__init__` 

1027 - :meth:`~colour.recovery.otsu2018.Node.split` 

1028 - :meth:`~colour.recovery.otsu2018.Node.minimise` 

1029 - :meth:`~colour.recovery.otsu2018.Node.leaf_reconstruction_error` 

1030 - :meth:`~colour.recovery.otsu2018.Node.branch_reconstruction_error` 

1031 """ 

1032 

1033 def __init__( 

1034 self, 

1035 parent: Self | None = None, 

1036 children: list | None = None, 

1037 data: Data_Otsu2018 | None = None, 

1038 ) -> None: 

1039 super().__init__(parent=parent, children=children, data=data) 

1040 

1041 self._partition_axis: PartitionAxis | None = None 

1042 self._best_partition: ( 

1043 Tuple[Sequence[Node_Otsu2018], PartitionAxis, float] | None 

1044 ) = None 

1045 

1046 @property 

1047 def partition_axis(self) -> PartitionAxis | None: 

1048 """ 

1049 Getter for the node partition axis. 

1050 

1051 Returns 

1052 ------- 

1053 :class:`colour.recovery.otsu2018.PartitionAxis` 

1054 Node partition axis. 

1055 """ 

1056 

1057 return self._partition_axis 

1058 

1059 @property 

1060 def row(self) -> Tuple[float, float, Self, Self]: 

1061 """ 

1062 Getter for the node row of the selector array. 

1063 

1064 Returns 

1065 ------- 

1066 :class:`tuple` 

1067 Node row for the selector array. 

1068 

1069 Raises 

1070 ------ 

1071 ValueError 

1072 If the partition axis is undefined. 

1073 """ 

1074 

1075 if self._partition_axis is not None: 

1076 return ( 

1077 self._partition_axis.origin, 

1078 self._partition_axis.direction, 

1079 self.children[0], 

1080 self.children[1], 

1081 ) 

1082 

1083 error = 'The "partition axis" is undefined!' 

1084 

1085 raise ValueError(error) 

1086 

1087 def split(self, children: Sequence[Self], axis: PartitionAxis) -> None: 

1088 """ 

1089 Convert the leaf node into an inner node using the specified children and 

1090 partition axis. 

1091 

1092 Parameters 

1093 ---------- 

1094 children 

1095 Tuple of two :class:`colour.recovery.otsu2018.Node` class 

1096 instances. 

1097 axis 

1098 Partition axis. 

1099 """ 

1100 

1101 self.data = None 

1102 self.children = list(children) 

1103 

1104 self._best_partition = None 

1105 self._partition_axis = axis 

1106 

1107 def minimise( 

1108 self, minimum_cluster_size: int 

1109 ) -> Tuple[Sequence[Node_Otsu2018], PartitionAxis, float]: 

1110 """ 

1111 Minimise the leaf reconstruction error by finding the best partition 

1112 for the node. 

1113 

1114 Parameters 

1115 ---------- 

1116 minimum_cluster_size 

1117 Smallest acceptable cluster size. Must be at least 3 to enable 

1118 *Principal Component Analysis* (PCA). 

1119 

1120 Returns 

1121 ------- 

1122 :class:`tuple` 

1123 Tuple containing nodes created by splitting this node with the 

1124 optimal partition, the partition axis (horizontal or vertical 

1125 line partitioning the 2D space into two half-planes), and the 

1126 partition error. 

1127 """ 

1128 

1129 if self._best_partition is not None: 

1130 return self._best_partition 

1131 

1132 leaf_error = self.leaf_reconstruction_error() 

1133 best_error = None 

1134 

1135 with tqdm(total=2 * len(self.data)) as progress: 

1136 for direction in [0, 1]: 

1137 for i in range(len(self.data)): 

1138 progress.update() 

1139 

1140 axis = PartitionAxis(self.data.origin(i, direction), direction) 

1141 data_lesser, data_greater = self.data.partition(axis) 

1142 

1143 if np.any( 

1144 np.array( 

1145 [ 

1146 len(data_lesser), 

1147 len(data_greater), 

1148 ] 

1149 ) 

1150 < minimum_cluster_size 

1151 ): 

1152 continue 

1153 

1154 lesser = Node_Otsu2018(data=data_lesser) 

1155 lesser.data.PCA() 

1156 

1157 greater = Node_Otsu2018(data=data_greater) 

1158 greater.data.PCA() 

1159 

1160 partition_error = ( 

1161 lesser.leaf_reconstruction_error() 

1162 + greater.leaf_reconstruction_error() 

1163 ) 

1164 

1165 partition = [lesser, greater] 

1166 

1167 if partition_error >= leaf_error: 

1168 continue 

1169 

1170 if best_error is None or partition_error < best_error: 

1171 self._best_partition = ( 

1172 partition, 

1173 axis, 

1174 partition_error, 

1175 ) 

1176 

1177 if self._best_partition is None: 

1178 error = "Could not find the best partition!" 

1179 

1180 raise RuntimeError(error) 

1181 

1182 return self._best_partition 

1183 

1184 def leaf_reconstruction_error(self) -> float: 

1185 """ 

1186 Compute the reconstruction error of the node data. 

1187 

1188 The error is computed by reconstructing the reflectances for the data 

1189 reference *CIE XYZ* tristimulus values using PCA and comparing the 

1190 reconstructed reflectances against the data reference reflectances. 

1191 

1192 Returns 

1193 ------- 

1194 :class:`float` 

1195 Reconstruction errors summation for the node data. 

1196 """ 

1197 

1198 return self.data.reconstruction_error() 

1199 

1200 def branch_reconstruction_error(self) -> float: 

1201 """ 

1202 Compute the reconstruction error for all leaves data connected to the 

1203 node or its children. 

1204 

1205 The reconstruction error is the summation of errors for all leaves in 

1206 the branch. 

1207 

1208 Returns 

1209 ------- 

1210 :class:`float` 

1211 Summation of reconstruction errors for all leaves data in the 

1212 branch. 

1213 """ 

1214 

1215 if self.is_leaf(): 

1216 return self.leaf_reconstruction_error() 

1217 

1218 return as_float_scalar( 

1219 np.sum([child.branch_reconstruction_error() for child in self.children]) 

1220 ) 

1221 

1222 

class Tree_Otsu2018(Node_Otsu2018):
    """
    Sub-class of :class:`colour.recovery.otsu2018.Node_Otsu2018` representing
    the root node of a tree containing information shared with all nodes,
    such as the standard observer colour matching functions and the
    illuminant, if any is used.

    Implement global operations involving the entire tree, such as
    optimisation and conversion to dataset.

    Parameters
    ----------
    reflectances
        Reference reflectances of the *n* reference colours to use for
        optimisation. They are reshaped to the spectral shape of the colour
        matching functions.
    cmfs
        Standard observer colour matching functions, default to the
        *CIE 1931 2 Degree Standard Observer*.
    illuminant
        Illuminant spectral distribution, default to
        *CIE Standard Illuminant D65*.

    Attributes
    ----------
    -   :attr:`~colour.recovery.Tree_Otsu2018.reflectances`
    -   :attr:`~colour.recovery.Tree_Otsu2018.cmfs`
    -   :attr:`~colour.recovery.Tree_Otsu2018.illuminant`

    Methods
    -------
    -   :meth:`~colour.recovery.otsu2018.Tree_Otsu2018.__init__`
    -   :meth:`~colour.recovery.otsu2018.Tree_Otsu2018.__str__`
    -   :meth:`~colour.recovery.otsu2018.Tree_Otsu2018.optimise`
    -   :meth:`~colour.recovery.otsu2018.Tree_Otsu2018.to_dataset`

    References
    ----------
    :cite:`Otsu2018`

    Examples
    --------
    >>> import os
    >>> import colour
    >>> from colour import MSDS_CMFS, SDS_COLOURCHECKERS, SDS_ILLUMINANTS
    >>> from colour.colorimetry import sds_and_msds_to_msds
    >>> from colour.utilities import numpy_print_options
    >>> XYZ = np.array([0.20654008, 0.12197225, 0.05136952])
    >>> cmfs = (
    ...     MSDS_CMFS["CIE 1931 2 Degree Standard Observer"]
    ...     .copy()
    ...     .align(SpectralShape(360, 780, 10))
    ... )
    >>> illuminant = SDS_ILLUMINANTS["D65"].copy().align(cmfs.shape)
    >>> reflectances = sds_and_msds_to_msds(
    ...     SDS_COLOURCHECKERS["ColorChecker N Ohta"].values()
    ... )
    >>> node_tree = Tree_Otsu2018(reflectances, cmfs, illuminant)
    >>> node_tree.optimise(iterations=2, print_callable=lambda x: x)
    >>> dataset = node_tree.to_dataset()
    >>> path = os.path.join(
    ...     colour.__path__[0],
    ...     "recovery",
    ...     "tests",
    ...     "resources",
    ...     "ColorChecker_Otsu2018.npz",
    ... )
    >>> dataset.write(path)  # doctest: +SKIP
    >>> dataset = Dataset_Otsu2018()  # doctest: +SKIP
    >>> dataset.read(path)  # doctest: +SKIP
    >>> sd = XYZ_to_sd_Otsu2018(XYZ, cmfs, illuminant, dataset)
    >>> with numpy_print_options(suppress=True):
    ...     sd  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    SpectralDistribution([[ 360.        ,    0.0651341...],
                          [ 370.        ,    0.0651341...],
                          [ 380.        ,    0.0651341...],
                          [ 390.        ,    0.0749684...],
                          [ 400.        ,    0.0815578...],
                          [ 410.        ,    0.0776439...],
                          [ 420.        ,    0.0721897...],
                          [ 430.        ,    0.0649064...],
                          [ 440.        ,    0.0567185...],
                          [ 450.        ,    0.0484685...],
                          [ 460.        ,    0.0409768...],
                          [ 470.        ,    0.0358964...],
                          [ 480.        ,    0.0307857...],
                          [ 490.        ,    0.0270148...],
                          [ 500.        ,    0.0273773...],
                          [ 510.        ,    0.0303157...],
                          [ 520.        ,    0.0331285...],
                          [ 530.        ,    0.0363027...],
                          [ 540.        ,    0.0425987...],
                          [ 550.        ,    0.0513442...],
                          [ 560.        ,    0.0579256...],
                          [ 570.        ,    0.0653850...],
                          [ 580.        ,    0.0929522...],
                          [ 590.        ,    0.1600326...],
                          [ 600.        ,    0.2586159...],
                          [ 610.        ,    0.3701242...],
                          [ 620.        ,    0.4702243...],
                          [ 630.        ,    0.5396261...],
                          [ 640.        ,    0.5737561...],
                          [ 650.        ,    0.590848 ...],
                          [ 660.        ,    0.5935371...],
                          [ 670.        ,    0.5923295...],
                          [ 680.        ,    0.5956326...],
                          [ 690.        ,    0.5982513...],
                          [ 700.        ,    0.6017904...],
                          [ 710.        ,    0.6016419...],
                          [ 720.        ,    0.5996892...],
                          [ 730.        ,    0.6000018...],
                          [ 740.        ,    0.5964443...],
                          [ 750.        ,    0.5868181...],
                          [ 760.        ,    0.5860973...],
                          [ 770.        ,    0.5614878...],
                          [ 780.        ,    0.5289331...]],
                         SpragueInterpolator,
                         {},
                         Extrapolator,
                         {'method': 'Constant', 'left': None, 'right': None})
    """

1343 

1344 def __init__( 

1345 self, 

1346 reflectances: MultiSpectralDistributions, 

1347 cmfs: MultiSpectralDistributions | None = None, 

1348 illuminant: SpectralDistribution | None = None, 

1349 ) -> None: 

1350 super().__init__() 

1351 

1352 cmfs, illuminant = handle_spectral_arguments( 

1353 cmfs, illuminant, shape_default=SPECTRAL_SHAPE_OTSU2018 

1354 ) 

1355 

1356 self._cmfs: MultiSpectralDistributions = cmfs 

1357 self._illuminant: SpectralDistribution = illuminant 

1358 

1359 self._reflectances: NDArrayFloat = np.transpose( 

1360 reshape_msds(reflectances, self._cmfs.shape, copy=False).values 

1361 ) 

1362 

1363 self.data: Data_Otsu2018 = Data_Otsu2018( 

1364 self._reflectances, self._cmfs, self._illuminant 

1365 ) 

1366 

1367 @property 

1368 def reflectances(self) -> NDArrayFloat: 

1369 """ 

1370 Getter for the reference reflectances. 

1371 

1372 Returns 

1373 ------- 

1374 :class:`numpy.ndarray` 

1375 Reference reflectances. 

1376 """ 

1377 

1378 return self._reflectances 

1379 

1380 @property 

1381 def cmfs(self) -> MultiSpectralDistributions: 

1382 """ 

1383 Getter for the standard observer colour matching functions. 

1384 

1385 Returns 

1386 ------- 

1387 :class:`colour.MultiSpectralDistributions` 

1388 Standard observer colour matching functions. 

1389 """ 

1390 

1391 return self._cmfs 

1392 

1393 @property 

1394 def illuminant(self) -> SpectralDistribution: 

1395 """ 

1396 Getter for the test illuminant. 

1397 

1398 Returns 

1399 ------- 

1400 :class:`colour.SpectralDistribution` 

1401 Test illuminant spectral distribution. 

1402 """ 

1403 

1404 return self._illuminant 

1405 

1406 def optimise( 

1407 self, 

1408 iterations: int = 8, 

1409 minimum_cluster_size: int | None = None, 

1410 print_callable: Callable = print, 

1411 ) -> None: 

1412 """ 

1413 Optimise the tree by repeatedly performing optimal partitioning of 

1414 nodes, creating a tree that minimises the total reconstruction error. 

1415 

1416 Parameters 

1417 ---------- 

1418 iterations 

1419 Maximum number of splits. If the dataset is too small, this 

1420 number might not be reached. The default is to create 8 clusters, 

1421 as described in :cite:`Otsu2018`. 

1422 minimum_cluster_size 

1423 Smallest acceptable cluster size. By default, it is chosen 

1424 automatically based on the dataset size and desired number of 

1425 clusters. It must be at least 3 or *Principal Component Analysis* 

1426 (PCA) will not be possible. 

1427 print_callable 

1428 Callable used to print progress and diagnostic information. 

1429 

1430 Examples 

1431 -------- 

1432 >>> from colour.colorimetry import sds_and_msds_to_msds 

1433 >>> from colour import MSDS_CMFS, SDS_COLOURCHECKERS, SDS_ILLUMINANTS 

1434 >>> cmfs = ( 

1435 ... MSDS_CMFS["CIE 1931 2 Degree Standard Observer"] 

1436 ... .copy() 

1437 ... .align(SpectralShape(360, 780, 10)) 

1438 ... ) 

1439 >>> illuminant = SDS_ILLUMINANTS["D65"].copy().align(cmfs.shape) 

1440 >>> reflectances = sds_and_msds_to_msds( 

1441 ... SDS_COLOURCHECKERS["ColorChecker N Ohta"].values() 

1442 ... ) 

1443 >>> node_tree = Tree_Otsu2018(reflectances, cmfs, illuminant) 

1444 >>> node_tree.optimise(iterations=2) # doctest: +ELLIPSIS 

1445 ======================================================================\ 

1446========= 

1447 * \ 

1448 * 

1449 * "Otsu et al. (2018)" Tree Optimisation \ 

1450 * 

1451 * \ 

1452 * 

1453 ======================================================================\ 

1454========= 

1455 Initial branch error is: 4.8705353... 

1456 <BLANKLINE> 

1457 Iteration 1 of 2: 

1458 <BLANKLINE> 

1459 Optimising "Tree_Otsu2018#...(Data_Otsu2018(24 Reflectances))"... 

1460 <BLANKLINE> 

1461 Splitting "Tree_Otsu2018#...(Data_Otsu2018(24 Reflectances))" into \ 

1462"Node_Otsu2018#...(Data_Otsu2018(10 Reflectances))" and \ 

1463"Node_Otsu2018#...(Data_Otsu2018(14 Reflectances))" along \ 

1464"PartitionAxis(horizontal partition at y = 0.3240945...)". 

1465 Error is reduced by 0.0054840... and is now 4.8650513..., 99.9% of \ 

1466the initial error. 

1467 <BLANKLINE> 

1468 Iteration 2 of 2: 

1469 <BLANKLINE> 

1470 Optimising "Node_Otsu2018#...(Data_Otsu2018(10 Reflectances))"... 

1471 Optimisation failed: Could not find the best partition! 

1472 Optimising "Node_Otsu2018#...(Data_Otsu2018(14 Reflectances))"... 

1473 <BLANKLINE> 

1474 Splitting "Node_Otsu2018#...(Data_Otsu2018(14 Reflectances))" into \ 

1475"Node_Otsu2018#...(Data_Otsu2018(7 Reflectances))" and \ 

1476"Node_Otsu2018#...(Data_Otsu2018(7 Reflectances))" along \ 

1477"PartitionAxis(horizontal partition at y = 0.3600663...)". 

1478 Error is reduced by 0.9681059... and is now 3.8969453..., 80.0% of \ 

1479the initial error. 

1480 Tree optimisation is complete! 

1481 >>> print(node_tree.render()) # doctest: +ELLIPSIS 

1482 |----"Tree_Otsu2018#..." 

1483 |----"Node_Otsu2018#..." 

1484 |----"Node_Otsu2018#..." 

1485 |----"Node_Otsu2018#..." 

1486 |----"Node_Otsu2018#..." 

1487 <BLANKLINE> 

1488 >>> len(node_tree) 

1489 4 

1490 """ 

1491 

1492 default_cluster_size = len(self.data) / iterations // 2 

1493 minimum_cluster_size = max( 

1494 cast("int", optional(minimum_cluster_size, default_cluster_size)), 3 

1495 ) 

1496 

1497 initial_branch_error = self.branch_reconstruction_error() 

1498 

1499 message_box( 

1500 '"Otsu et al. (2018)" Tree Optimisation', 

1501 print_callable=print_callable, 

1502 ) 

1503 

1504 print_callable(f"Initial branch error is: {initial_branch_error}") 

1505 

1506 best_leaf, best_partition, best_axis, partition_error = [None] * 4 

1507 

1508 for i in range(iterations): 

1509 print_callable(f"\nIteration {i + 1} of {iterations}:\n") 

1510 

1511 total_error = self.branch_reconstruction_error() 

1512 optimised_total_error = None 

1513 

1514 for leaf in self.leaves: 

1515 print_callable(f'Optimising "{leaf}"...') 

1516 

1517 try: 

1518 partition, axis, partition_error = leaf.minimise( 

1519 minimum_cluster_size 

1520 ) 

1521 except RuntimeError as error: 

1522 print_callable(f"Optimisation failed: {error}") 

1523 continue 

1524 

1525 new_total_error = ( 

1526 total_error - leaf.leaf_reconstruction_error() + partition_error 

1527 ) 

1528 

1529 if ( 

1530 optimised_total_error is None 

1531 or new_total_error < optimised_total_error 

1532 ): 

1533 optimised_total_error = new_total_error 

1534 best_axis = axis 

1535 best_leaf = leaf 

1536 best_partition = partition 

1537 

1538 if optimised_total_error is None: 

1539 print_callable( 

1540 f"\nNo further improvement is possible!" 

1541 f"\nTerminating at iteration {i}.\n" 

1542 ) 

1543 break 

1544 

1545 if best_partition is not None: 

1546 print_callable( 

1547 f'\nSplitting "{best_leaf}" into "{best_partition[0]}" ' 

1548 f'and "{best_partition[1]}" along "{best_axis}".' 

1549 ) 

1550 

1551 print_callable( 

1552 f"Error is reduced by " 

1553 f"{leaf.leaf_reconstruction_error() - partition_error} and " 

1554 f"is now {optimised_total_error}, " 

1555 f"{100 * optimised_total_error / initial_branch_error:.1f}% " 

1556 f"of the initial error." 

1557 ) 

1558 

1559 if best_leaf is not None: 

1560 best_leaf.split(best_partition, best_axis) 

1561 

1562 print_callable("Tree optimisation is complete!") 

1563 

1564 def to_dataset(self) -> Dataset_Otsu2018: 

1565 """ 

1566 Create a :class:`colour.recovery.Dataset_Otsu2018` class instance 

1567 based on data stored in the tree. 

1568 

1569 The dataset can then be saved to disk or used to recover reflectance 

1570 with the :func:`colour.recovery.XYZ_to_sd_Otsu2018` definition. 

1571 

1572 Returns 

1573 ------- 

1574 :class:`colour.recovery.Dataset_Otsu2018` 

1575 Dataset object. 

1576 

1577 Examples 

1578 -------- 

1579 >>> from colour.colorimetry import sds_and_msds_to_msds 

1580 >>> from colour.characterisation import SDS_COLOURCHECKERS 

1581 >>> reflectances = sds_and_msds_to_msds( 

1582 ... SDS_COLOURCHECKERS["ColorChecker N Ohta"].values() 

1583 ... ) 

1584 >>> node_tree = Tree_Otsu2018(reflectances) 

1585 >>> node_tree.optimise(iterations=2, print_callable=lambda x: x) 

1586 >>> node_tree.to_dataset() # doctest: +ELLIPSIS 

1587 <colour.recovery.otsu2018.Dataset_Otsu2018 object at 0x...> 

1588 """ 

1589 

1590 basis_functions = as_float_array( 

1591 [leaf.data.basis_functions for leaf in self.leaves] 

1592 ) 

1593 

1594 means = as_float_array([leaf.data.mean for leaf in self.leaves]) 

1595 

1596 if len(self.children) == 0: 

1597 selector_array = zeros(4) 

1598 else: 

1599 

1600 def add_rows(node: Node_Otsu2018, data: dict | None = None) -> dict | None: 

1601 """Add rows for the specified node and its children.""" 

1602 

1603 data = optional(data, {"rows": [], "node_to_leaf_id": {}, "leaf_id": 0}) 

1604 

1605 if node.is_leaf(): 

1606 data["node_to_leaf_id"][node] = data["leaf_id"] 

1607 data["leaf_id"] += 1 

1608 return None 

1609 

1610 data["node_to_leaf_id"][node] = -len(data["rows"]) 

1611 data["rows"].append(list(node.row)) 

1612 

1613 for child in node.children: 

1614 add_rows(child, data) 

1615 

1616 return data 

1617 

1618 data = cast("dict", add_rows(self)) 

1619 rows = data["rows"] 

1620 

1621 for i, row in enumerate(rows): 

1622 for j in (2, 3): 

1623 rows[i][j] = data["node_to_leaf_id"][row[j]] 

1624 

1625 selector_array = as_float_array(rows) 

1626 

1627 return Dataset_Otsu2018( 

1628 self._cmfs.shape, 

1629 basis_functions, 

1630 means, 

1631 selector_array, 

1632 )