|
|
|
@ -1,4 +1,4 @@ |
|
|
|
|
# $Id: fitting.py,v 1.1 2010-01-22 18:44:59 wirawan Exp $ |
|
|
|
|
# $Id: fitting.py,v 1.2 2010-05-28 18:43:39 wirawan Exp $ |
|
|
|
|
# |
|
|
|
|
# wpylib.math.fitting module |
|
|
|
|
# Created: 20100120 |
|
|
|
@ -118,18 +118,27 @@ class Poly_order4(Poly_base): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fit_func(Funct, Data=None, Guess=None, x=None, y=None): |
|
|
|
|
''' |
|
|
|
|
def fit_func(Funct, Data=None, Guess=None, x=None, y=None, |
|
|
|
|
debug=10, |
|
|
|
|
method='leastsq', opts={}): |
|
|
|
|
""" |
|
|
|
|
Performs a function fitting. |
|
|
|
|
The domain of the function is a D-dimensional vector, and the function |
|
|
|
|
yields a scalar. |
|
|
|
|
|
|
|
|
|
Funct is a python function (or any callable object) with argument list of |
|
|
|
|
(C, x), where: |
|
|
|
|
* C is the cofficients (parameters) to be adjusted by the fitting process |
|
|
|
|
* C is the cofficients (parameters) being adjusted by the fitting process |
|
|
|
|
(it is a sequence or a 1-D array) |
|
|
|
|
* x is a 2-D array (or sequence of like nature). The "row" size is the dimensionality |
|
|
|
|
of the domain, while the "column" is the number of data points, whose count must be |
|
|
|
|
equal to the size of y data below. |
|
|
|
|
* x is a 2-D array (or sequence of like nature), say, |
|
|
|
|
of size "N rows times M columns". |
|
|
|
|
N is the dimensionality of the domain, while |
|
|
|
|
M is the number of data points, whose count must be equal to the |
|
|
|
|
size of y data below. |
|
|
|
|
For a 2-D fitting, for example, x should be a column array. |
|
|
|
|
|
|
|
|
|
Inspect Poly_base, Poly_order2, and other similar function classes in this module |
|
|
|
|
to see the example of the Funct function. |
|
|
|
|
Inspect Poly_base, Poly_order2, and other similar function classes in this |
|
|
|
|
module to see the example of the Funct function. |
|
|
|
|
|
|
|
|
|
The measurement (input) datasets, against which the function is to be fitted, |
|
|
|
|
can be specified in one of two ways: |
|
|
|
@ -139,11 +148,10 @@ def fit_func(Funct, Data=None, Guess=None, x=None, y=None): |
|
|
|
|
Or, |
|
|
|
|
* via Data argument (which is a multi-column dataset |
|
|
|
|
|
|
|
|
|
''' |
|
|
|
|
""" |
|
|
|
|
global last_fit_rslt, last_chi_sqr |
|
|
|
|
from scipy.optimize import leastsq |
|
|
|
|
from scipy.optimize import leastsq, anneal |
|
|
|
|
# We want to minimize this error: |
|
|
|
|
fun_err = lambda CC, xx, yy: abs(Funct(CC,xx) - yy) |
|
|
|
|
if Data != None: # an alternative way to specifying x and y |
|
|
|
|
y = Data[0] |
|
|
|
|
x = Data[1:] # possibly multidimensional! |
|
|
|
@ -152,12 +160,36 @@ def fit_func(Funct, Data=None, Guess=None, x=None, y=None): |
|
|
|
|
Guess = Funct.Guess(y) |
|
|
|
|
elif Guess == None: # VERY OLD, DO NOT USE ANYMORE! |
|
|
|
|
Guess = [ y.mean() ] + [0.0, 0.0] * len(x) |
|
|
|
|
rslt = leastsq(fun_err, |
|
|
|
|
x0=Guess, # initial coefficient guess |
|
|
|
|
args=(x,y), # data onto which the function is fitted |
|
|
|
|
full_output=1) |
|
|
|
|
|
|
|
|
|
fun_err = lambda CC, xx, yy: abs(Funct(CC,xx) - yy) |
|
|
|
|
fun_err2 = lambda CC, xx, yy: numpy.sum(abs(Funct(CC,xx) - yy)**2) |
|
|
|
|
|
|
|
|
|
if debug >= 5: |
|
|
|
|
print "Guess params:" |
|
|
|
|
print Guess |
|
|
|
|
|
|
|
|
|
if method == 'leastsq': |
|
|
|
|
rslt = leastsq(fun_err, |
|
|
|
|
x0=Guess, # initial coefficient guess |
|
|
|
|
args=(x,y), # data onto which the function is fitted |
|
|
|
|
full_output=1, |
|
|
|
|
**opts |
|
|
|
|
) |
|
|
|
|
elif method == 'anneal': |
|
|
|
|
rslt = anneal(fun_err2, |
|
|
|
|
x0=Guess, # initial coefficient guess |
|
|
|
|
args=(x,y), # data onto which the function is fitted |
|
|
|
|
full_output=1, |
|
|
|
|
**opts |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
raise ValueError, "Unsupported minimization method: %s" % method |
|
|
|
|
last_fit_rslt = rslt |
|
|
|
|
last_chi_sqr = sum( fun_err(rslt[0], x, y)**2 ) |
|
|
|
|
last_chi_sqr = fun_err2(rslt[0], x, y) |
|
|
|
|
if (debug >= 10): |
|
|
|
|
#print "Fit-message: ", rslt[] |
|
|
|
|
print "Fit-result:" |
|
|
|
|
print "\n".join([ "%2d %s" % (ii, rslt[ii]) for ii in xrange(len(rslt)) ]) |
|
|
|
|
print "params = ", rslt[0] |
|
|
|
|
print "chi square = ", last_chi_sqr / len(y) |
|
|
|
|
return rslt[0] |
|
|
|
|