kompose/vendor/github.com/gonum/matrix/mat64/symmetric.go
Tomas Kral 1f8a0e06c9
Upgrade OpenShift and its dependencies.
OpenShift version 1.4.0-alpha.0
2016-10-18 12:04:00 +02:00

199 lines
4.3 KiB
Go

package mat64
import (
"github.com/gonum/blas"
"github.com/gonum/blas/blas64"
)
// Compile-time checks that *SymDense implements Matrix, Symmetric and
// RawSymmetricer. symDense is a package-level typed nil *SymDense; it stays
// declared in case other files of the package refer to it.
// NOTE(review): confirm whether symDense is used elsewhere and drop it if not.
var (
	symDense *SymDense

	_ Matrix         = (*SymDense)(nil)
	_ Symmetric      = (*SymDense)(nil)
	_ RawSymmetricer = (*SymDense)(nil)
)
// badSymTriangle is the panic message raised when raw symmetric storage whose
// Uplo is not blas.Upper is supplied where upper-triangular storage is required.
const (
	badSymTriangle = "mat64: blas64.Symmetric not upper"
)
// SymDense is a symmetric matrix that uses Dense storage.
//
// Only the upper triangle of the backing data is referenced by the methods
// in this file; entries below the diagonal are ignored.
type SymDense struct {
	// mat holds the dimension, row stride, backing data and triangle
	// designation. NewSymDense and the methods that size an empty receiver
	// set Uplo to blas.Upper; a zero-value SymDense leaves every field unset.
	mat blas64.Symmetric
}
// Symmetric represents a symmetric matrix (where the element at {i, j} equals
// the element at {j, i}). Symmetric matrices are always square.
type Symmetric interface {
	// Matrix provides the general matrix accessors (Dims, At, ...).
	Matrix
	// Symmetric returns the number of rows/columns in the matrix.
	Symmetric() int
}
// A RawSymmetricer can return a view of itself as a BLAS Symmetric matrix.
//
// The methods in this file that use this view read only the upper triangle
// of the returned data; CopySym panics if the returned Uplo is not
// blas.Upper.
type RawSymmetricer interface {
	// RawSymmetric returns the underlying BLAS representation.
	RawSymmetric() blas64.Symmetric
}
// NewSymDense constructs an n x n symmetric matrix.
//
// When mat is nil, fresh storage of n*n elements is allocated. Otherwise mat
// must have exactly n*n elements and is used directly (not copied) as the
// row-major backing store. As with a Dense matrix, only the upper triangle
// is referenced; entries in the lower triangle are ignored.
//
// NewSymDense panics if n is negative, and with ErrShape if a non-nil mat
// has the wrong length.
func NewSymDense(n int, mat []float64) *SymDense {
	switch {
	case n < 0:
		panic("mat64: negative dimension")
	case mat == nil:
		mat = make([]float64, n*n)
	case len(mat) != n*n:
		panic(ErrShape)
	}

	s := &SymDense{}
	s.mat = blas64.Symmetric{
		N:      n,
		Stride: n,
		Data:   mat,
		Uplo:   blas.Upper,
	}
	return s
}
// Dims returns the number of rows and columns of the matrix. The two are
// always equal because a SymDense is square.
func (s *SymDense) Dims() (r, c int) {
	r = s.mat.N
	c = r
	return r, c
}
// Symmetric returns the number of rows/columns in the matrix, satisfying the
// Symmetric interface.
func (s *SymDense) Symmetric() int {
	n := s.mat.N
	return n
}
// RawSymmetric returns a copy of the receiver's blas64.Symmetric header. The
// Data slice is shared with the receiver, not copied. The returned value must
// be stored in upper triangular format.
func (s *SymDense) RawSymmetric() blas64.Symmetric {
	raw := s.mat
	return raw
}
// isZero reports whether the receiver has dimension zero, as a zero-value
// SymDense does. Methods use this to decide whether an empty receiver should
// be sized to fit their result.
func (s *SymDense) isZero() bool {
	n := s.mat.N
	return n == 0
}
// AddSym adds the symmetric matrices a and b and stores the result in the
// receiver:
//
//	s = a + b
//
// An empty (zero-sized) receiver is sized to match a; otherwise the receiver's
// dimension must equal that of a. Only the upper triangle of the receiver is
// written. The receiver may be a or b, because each output element depends
// only on the corresponding elements of a and b.
//
// AddSym panics with ErrShape if a and b differ in size or the receiver has
// the wrong size. It panics with the badSymTriangle message if both operands
// expose raw storage that is not upper-triangular.
func (s *SymDense) AddSym(a, b Symmetric) {
	n := a.Symmetric()
	if n != b.Symmetric() {
		panic(ErrShape)
	}

	// Resolve the raw fast path before touching the receiver, so that a
	// malformed operand is rejected without leaving s partially initialized.
	var amat, bmat blas64.Symmetric
	raw := false
	if ra, ok := a.(RawSymmetricer); ok {
		if rb, ok := b.(RawSymmetricer); ok {
			amat, bmat = ra.RawSymmetric(), rb.RawSymmetric()
			raw = true
			// The loop below reads only the upper triangle, so lower-stored
			// operands would silently give wrong sums (CopySym performs the
			// same check). Zero-sized operands carry no data and are exempt,
			// since a zero-value SymDense has an unset Uplo.
			if n > 0 && (amat.Uplo != blas.Upper || bmat.Uplo != blas.Upper) {
				panic(badSymTriangle)
			}
		}
	}

	if s.isZero() {
		s.mat = blas64.Symmetric{
			N:      n,
			Stride: n,
			Data:   use(s.mat.Data, n*n),
			Uplo:   blas.Upper,
		}
	} else if s.mat.N != n {
		panic(ErrShape)
	}

	if raw {
		// Row-wise add of the upper triangle, columns i..n-1 of each row.
		for i := 0; i < n; i++ {
			btmp := bmat.Data[i*bmat.Stride+i : i*bmat.Stride+n]
			stmp := s.mat.Data[i*s.mat.Stride+i : i*s.mat.Stride+n]
			for j, v := range amat.Data[i*amat.Stride+i : i*amat.Stride+n] {
				stmp[j] = v + btmp[j]
			}
		}
		return
	}

	// Generic path: go through At for operands without raw storage.
	for i := 0; i < n; i++ {
		stmp := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n]
		for j := i; j < n; j++ {
			stmp[j] = a.At(i, j) + b.At(i, j)
		}
	}
}
// CopySym copies the upper triangle of the symmetric matrix a into the
// receiver and returns the number of rows/columns copied.
//
// Only the leading min(a.Symmetric(), s.Symmetric()) square is copied. An
// empty receiver is not grown, so CopySym into a zero-value SymDense copies
// nothing and returns 0.
//
// CopySym panics with the badSymTriangle message if a exposes raw storage
// whose Uplo is not blas.Upper and there is something to copy.
func (s *SymDense) CopySym(a Symmetric) int {
	n := a.Symmetric()
	n = min(n, s.mat.N)
	if n == 0 {
		// Nothing to copy. Returning here also avoids rejecting zero-sized
		// operands such as a zero-value SymDense, whose Uplo field is left
		// at its zero value rather than blas.Upper.
		return 0
	}
	switch a := a.(type) {
	case RawSymmetricer:
		amat := a.RawSymmetric()
		if amat.Uplo != blas.Upper {
			panic(badSymTriangle)
		}
		// Copy row i, columns i..n-1 (the upper triangle) in one call.
		for i := 0; i < n; i++ {
			copy(s.mat.Data[i*s.mat.Stride+i:i*s.mat.Stride+n], amat.Data[i*amat.Stride+i:i*amat.Stride+n])
		}
	default:
		for i := 0; i < n; i++ {
			stmp := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n]
			for j := i; j < n; j++ {
				stmp[j] = a.At(i, j)
			}
		}
	}
	return n
}
// SymRankOne performs a symmetric rank-one update to the matrix a and stores
// the result in the receiver:
//
//	s = a + alpha * x * x'
//
// The dimension is taken from a. An empty (zero-sized) receiver is sized to
// match a; otherwise the receiver's dimension must equal a.Symmetric(). The
// receiver may be a itself, in which case the update is done in place.
//
// SymRankOne panics with ErrShape if len(x) != a.Symmetric() or the receiver
// is non-empty with a different dimension.
func (s *SymDense) SymRankOne(a Symmetric, alpha float64, x []float64) {
	n := a.Symmetric()
	// Validate every shape before doing any work or mutating the receiver.
	// Previously n came from the receiver, so an empty receiver could not be
	// used and a mismatched one silently took only part of a.
	if len(x) != n {
		panic(ErrShape)
	}
	if !s.isZero() && s.mat.N != n {
		panic(ErrShape)
	}

	var w SymDense
	if s == a {
		// In-place update on the receiver's own storage.
		w = *s
	}
	if w.isZero() {
		// Either a is a different matrix (w has no storage yet) or the
		// receiver is empty; give w an n x n upper-triangular layout.
		w.mat = blas64.Symmetric{
			N:      n,
			Stride: n,
			Uplo:   blas.Upper,
			Data:   use(w.mat.Data, n*n),
		}
	}
	if s != a {
		w.CopySym(a)
	}
	blas64.Syr(alpha, blas64.Vector{Inc: 1, Data: x}, w.mat)
	*s = w
}
// RankTwo performs a symmetric rank-two update to the matrix a and stores
// the result in the receiver:
//
//	s = a + alpha * (x * y' + y * x')
//
// The dimension is taken from a. An empty (zero-sized) receiver is sized to
// match a; otherwise the receiver's dimension must equal a.Symmetric(). The
// receiver may be a itself, in which case the update is done in place.
//
// RankTwo panics with ErrShape if len(x) or len(y) differs from
// a.Symmetric(), or if the receiver is non-empty with a different dimension.
func (s *SymDense) RankTwo(a Symmetric, alpha float64, x, y []float64) {
	n := a.Symmetric()
	// Validate every shape before doing any work or mutating the receiver.
	// Previously n came from the receiver, so an empty receiver could not be
	// used and a mismatched one silently took only part of a.
	if len(x) != n {
		panic(ErrShape)
	}
	if len(y) != n {
		panic(ErrShape)
	}
	if !s.isZero() && s.mat.N != n {
		panic(ErrShape)
	}

	var w SymDense
	if s == a {
		// In-place update on the receiver's own storage.
		w = *s
	}
	if w.isZero() {
		// Either a is a different matrix (w has no storage yet) or the
		// receiver is empty; give w an n x n upper-triangular layout.
		w.mat = blas64.Symmetric{
			N:      n,
			Stride: n,
			Uplo:   blas.Upper,
			Data:   use(w.mat.Data, n*n),
		}
	}
	if s != a {
		w.CopySym(a)
	}
	blas64.Syr2(alpha, blas64.Vector{Inc: 1, Data: x}, blas64.Vector{Inc: 1, Data: y}, w.mat)
	*s = w
}