LU.php 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. <?php
  2. namespace Matrix\Decomposition;
  3. use Matrix\Exception;
  4. use Matrix\Matrix;
  /**
   * LU decomposition with partial (row) pivoting.
   *
   * For an m x n matrix A this computes a unit lower triangular L, an upper
   * triangular U and a row permutation P such that P*A = L*U. Both factors are
   * stored packed in a single working array: U on and above the diagonal, the
   * multipliers of L strictly below it (L's unit diagonal is implied).
   */
  class LU
  {
      /**
       * Packed L and U factors (0-indexed rows of 0-indexed columns).
       *
       * @var array
       */
      private $luMatrix;

      /** @var int Number of rows of the decomposed matrix. */
      private $rows;

      /** @var int Number of columns of the decomposed matrix. */
      private $columns;

      /**
       * Row permutation: $pivot[$i] is the index of the original row now stored at row $i.
       *
       * @var int[]
       */
      private $pivot = [];

      /**
       * Decompose the given matrix.
       *
       * @param Matrix $matrix Matrix to decompose; only its array copy is modified
       */
      public function __construct(Matrix $matrix)
      {
          $this->luMatrix = $matrix->toArray();
          $this->rows = $matrix->rows;
          $this->columns = $matrix->columns;

          $this->buildPivot();
      }

      /**
       * Get the lower triangular factor L (rows x min(rows, columns), unit diagonal).
       *
       * @return Matrix Lower triangular factor
       */
      public function getL()
      {
          $lower = [];
          $columns = min($this->rows, $this->columns);
          for ($row = 0; $row < $this->rows; ++$row) {
              for ($column = 0; $column < $columns; ++$column) {
                  if ($row > $column) {
                      $lower[$row][$column] = $this->luMatrix[$row][$column];
                  } elseif ($row === $column) {
                      $lower[$row][$column] = 1.0;
                  } else {
                      $lower[$row][$column] = 0.0;
                  }
              }
          }

          return new Matrix($lower);
      }

      /**
       * Get the upper triangular factor U (min(rows, columns) x columns).
       *
       * @return Matrix Upper triangular factor
       */
      public function getU()
      {
          $upper = [];
          $rows = min($this->rows, $this->columns);
          for ($row = 0; $row < $rows; ++$row) {
              for ($column = 0; $column < $this->columns; ++$column) {
                  if ($row <= $column) {
                      $upper[$row][$column] = $this->luMatrix[$row][$column];
                  } else {
                      $upper[$row][$column] = 0.0;
                  }
              }
          }

          return new Matrix($upper);
      }

      /**
       * Return the permutation matrix P such that P*A = L*U.
       *
       * Row $i of P has a single 1 in column $pivot[$i].
       *
       * @return Matrix Pivot (permutation) matrix
       */
      public function getP()
      {
          $pMatrix = [];
          $pivotCount = count($this->pivot);
          foreach ($this->pivot as $row => $pivot) {
              $pMatrix[$row] = array_fill(0, $pivotCount, 0);
              $pMatrix[$row][$pivot] = 1;
          }

          return new Matrix($pMatrix);
      }

      /**
       * Return the pivot permutation vector.
       *
       * @return int[] $pivot[$i] is the index of the original row now at row $i
       */
      public function getPivot()
      {
          return $this->pivot;
      }

      /**
       * Is the matrix nonsingular?
       *
       * Checks the diagonal of U for exact zeros. Only the min(rows, columns)
       * diagonal entries that actually exist are inspected, so wide matrices no
       * longer read past the last row. No tolerance is applied: nearly singular
       * matrices are reported as nonsingular.
       *
       * @return bool true if U, and hence A, is nonsingular
       */
      public function isNonsingular()
      {
          $diagonalLength = min($this->rows, $this->columns);
          for ($diagonal = 0; $diagonal < $diagonalLength; ++$diagonal) {
              // Every processed entry is a float (see applyTransformations), so a strict compare is safe.
              if ($this->luMatrix[$diagonal][$diagonal] === 0.0) {
                  return false;
              }
          }

          return true;
      }

      /**
       * Run the column-by-column elimination that fills $luMatrix and $pivot.
       */
      private function buildPivot()
      {
          for ($row = 0; $row < $this->rows; ++$row) {
              $this->pivot[$row] = $row;
          }

          for ($column = 0; $column < $this->columns; ++$column) {
              $luColumn = $this->applyTransformations($column, $this->extractColumn($column));
              $pivot = $this->findPivot($column, $luColumn);
              if ($pivot !== $column) {
                  $this->pivotExchange($pivot, $column);
              }
              $this->computeMultipliers($column);
          }
      }

      /**
       * Copy one column of the working matrix into a plain (non-reference) array.
       *
       * @param int $column Column index
       *
       * @return array Column values indexed by row
       */
      private function extractColumn($column)
      {
          $luColumn = [];
          for ($row = 0; $row < $this->rows; ++$row) {
              $luColumn[$row] = $this->luMatrix[$row][$column];
          }

          return $luColumn;
      }

      /**
       * Apply the previous elimination steps to one column, writing the result back.
       *
       * @param int   $column   Column index being processed
       * @param array $luColumn Current values of that column, indexed by row
       *
       * @return array The updated column (also stored in $luMatrix)
       */
      private function applyTransformations($column, array $luColumn)
      {
          for ($row = 0; $row < $this->rows; ++$row) {
              $luRow = $this->luMatrix[$row];
              // Most of the time is spent in the following dot product.
              $kmax = min($row, $column);
              $sValue = 0.0;
              for ($kValue = 0; $kValue < $kmax; ++$kValue) {
                  // $luColumn[$kValue] for $kValue < $row was already updated earlier in this loop.
                  $sValue += $luRow[$kValue] * $luColumn[$kValue];
              }
              // Subtracting a float guarantees every processed entry is stored as a float.
              $luColumn[$row] -= $sValue;
              $this->luMatrix[$row][$column] = $luColumn[$row];
          }

          return $luColumn;
      }

      /**
       * Find the row (at or below the diagonal) with the largest absolute value in this column.
       *
       * @param int   $column   Column index
       * @param array $luColumn Transformed column values, indexed by row
       *
       * @return int Row index of the pivot (the diagonal row if no larger value exists)
       */
      private function findPivot($column, array $luColumn)
      {
          $pivot = $column;
          for ($row = $column + 1; $row < $this->rows; ++$row) {
              if (abs($luColumn[$row]) > abs($luColumn[$pivot])) {
                  $pivot = $row;
              }
          }

          return $pivot;
      }

      /**
       * Swap two rows of the working matrix and the matching pivot entries.
       *
       * @param int $pivot  Row selected as pivot
       * @param int $column Diagonal row it is swapped with
       */
      private function pivotExchange($pivot, $column)
      {
          $tRow = $this->luMatrix[$pivot];
          $this->luMatrix[$pivot] = $this->luMatrix[$column];
          $this->luMatrix[$column] = $tRow;

          $lValue = $this->pivot[$pivot];
          $this->pivot[$pivot] = $this->pivot[$column];
          $this->pivot[$column] = $lValue;
      }

      /**
       * Divide the entries below the diagonal by the diagonal element (the L multipliers).
       *
       * Skipped when the diagonal index is past the last row or the diagonal is zero.
       *
       * @param int $diagonal Diagonal index
       */
      private function computeMultipliers($diagonal)
      {
          if ($diagonal >= $this->rows) {
              return;
          }

          $divisor = $this->luMatrix[$diagonal][$diagonal];
          if ($divisor === 0.0) {
              return;
          }

          for ($row = $diagonal + 1; $row < $this->rows; ++$row) {
              $this->luMatrix[$row][$diagonal] /= $divisor;
          }
      }

      /**
       * Return the rows of $B reordered by the pivot vector, as a plain array.
       *
       * @param Matrix $B Right-hand side
       *
       * @return array Row $i is row $pivot[$i] of $B
       */
      private function pivotB(Matrix $B)
      {
          // toArray() is 0-indexed, like the grid the constructor decomposes.
          $bArray = $B->toArray();
          $X = [];
          foreach ($this->pivot as $rowId) {
              $X[] = $bArray[$rowId];
          }

          return $X;
      }

      /**
       * Solve A*X = B.
       *
       * @param Matrix $B a Matrix with as many rows as A and any number of columns
       *
       * @throws Exception if the row counts differ, A is not square, or A is singular
       *
       * @return Matrix X so that L*U*X = B(piv,:)
       */
      public function solve(Matrix $B)
      {
          if ($B->rows !== $this->rows) {
              throw new Exception('Matrix row dimensions are not equal');
          }
          if ($this->rows !== $this->columns) {
              throw new Exception('LU solve() only works on square matrices');
          }
          if (!$this->isNonsingular()) {
              throw new Exception('Can only perform operation on nonsingular matrix');
          }

          // Copy right hand side with pivoting
          $nx = $B->columns;
          $X = $this->pivotB($B);
          $size = $this->columns;

          // Forward substitution: solve L*Y = B(piv,:) (L has an implicit unit diagonal)
          for ($k = 0; $k < $size; ++$k) {
              for ($i = $k + 1; $i < $size; ++$i) {
                  $multiplier = $this->luMatrix[$i][$k];
                  for ($j = 0; $j < $nx; ++$j) {
                      $X[$i][$j] -= $X[$k][$j] * $multiplier;
                  }
              }
          }

          // Back substitution: solve U*X = Y
          for ($k = $size - 1; $k >= 0; --$k) {
              $diagonal = $this->luMatrix[$k][$k];
              for ($j = 0; $j < $nx; ++$j) {
                  $X[$k][$j] /= $diagonal;
              }
              for ($i = 0; $i < $k; ++$i) {
                  $factor = $this->luMatrix[$i][$k];
                  for ($j = 0; $j < $nx; ++$j) {
                      $X[$i][$j] -= $X[$k][$j] * $factor;
                  }
              }
          }

          return new Matrix($X);
      }
  }