|
|
|
@ -19,15 +19,58 @@ import numpy.linalg |
|
|
|
|
from numpy import dot, trace |
|
|
|
|
from numpy.linalg import det, inv |
|
|
|
|
|
|
|
|
|
MATMUL_USE_BLAS = False |
|
|
|
|
|
|
|
|
|
def matmul(*Mats): |
|
|
|
|
def matmul(*Mats, **opts): |
|
|
|
|
"""Do successive matrix product. For example, |
|
|
|
|
matmul(A,B,C,D) |
|
|
|
|
will evaluate a matrix multiplication ((A*B)*C)*D . |
|
|
|
|
The matrices must be of matching sizes.""" |
|
|
|
|
p = numpy.dot(Mats[0], Mats[1]) |
|
|
|
|
for M in Mats[2:]: |
|
|
|
|
p = numpy.dot(p, M) |
|
|
|
|
from numpy import asarray, dot, iscomplexobj |
|
|
|
|
use_blas = opts.get('use_blas', MATMUL_USE_BLAS) |
|
|
|
|
debug = opts.get('debug', True) |
|
|
|
|
if debug: |
|
|
|
|
def dbg(msg): |
|
|
|
|
print msg, |
|
|
|
|
else: |
|
|
|
|
def dbg(msg): |
|
|
|
|
pass |
|
|
|
|
if use_blas: |
|
|
|
|
try: |
|
|
|
|
from scipy.linalg.blas import zgemm, dgemm |
|
|
|
|
except: |
|
|
|
|
# Older scipy (<= 0.10?) |
|
|
|
|
from scipy.linalg.blas import fblas |
|
|
|
|
zgemm = fblas.zgemm |
|
|
|
|
dgemm = fblas.dgemm |
|
|
|
|
|
|
|
|
|
if not use_blas: |
|
|
|
|
p = dot(Mats[0], Mats[1]) |
|
|
|
|
for M in Mats[2:]: |
|
|
|
|
p = dot(p, M) |
|
|
|
|
else: |
|
|
|
|
dbg("Using BLAS\n") |
|
|
|
|
# FIXME: Right now only supporting double precision arithmetic. |
|
|
|
|
M0 = asarray(Mats[0]) |
|
|
|
|
M1 = asarray(Mats[1]) |
|
|
|
|
if iscomplexobj(M0) or iscomplexobj(M1): |
|
|
|
|
p = zgemm(alpha=1.0, a=M0, b=M1) |
|
|
|
|
Cplx = True |
|
|
|
|
dbg("- zgemm ") |
|
|
|
|
else: |
|
|
|
|
p = dgemm(alpha=1.0, a=M0, b=M1) |
|
|
|
|
Cplx = False |
|
|
|
|
dbg("- dgemm ") |
|
|
|
|
for M in Mats[2:]: |
|
|
|
|
M2 = asarray(M) |
|
|
|
|
if Cplx or iscomplexobj(M2): |
|
|
|
|
p = zgemm(alpha=1.0, a=p, b=M2) |
|
|
|
|
Cplx = True |
|
|
|
|
dbg(" zgemm") |
|
|
|
|
else: |
|
|
|
|
p = dgemm(alpha=1.0, a=p, b=M2) |
|
|
|
|
dbg(" dgemm") |
|
|
|
|
dbg("\n") |
|
|
|
|
return p |
|
|
|
|
|
|
|
|
|
|
|
|
|
|