Skip to content

Commit a205331

Browse files
committed
Add Non-Linear Solver
0 parents commit a205331

File tree

1 file changed

+290
-0
lines changed

1 file changed

+290
-0
lines changed

NonLinearSolve.py

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
from __future__ import division
2+
from math import *
3+
from cmath import *
4+
import numpy as np
5+
import matplotlib.pyplot as plt
6+
7+
8+
def sgn(number):
9+
# only for real part
10+
return number.real > 0
11+
12+
13+
class NonLinearSolve():
14+
def __init__(self, user_func, start, max_itr, rel_err, method):
15+
self.user_func = user_func
16+
self.start = start
17+
self.max_itr = max_itr
18+
self.rel_err = rel_err
19+
self.error_list = []
20+
self.itr_list = []
21+
self.root = None
22+
self.solvemethod = method
23+
24+
# Methods supported
25+
methods = ['Bisection', 'False_Position', 'Fixed_Point',
26+
'Newton_Raphson', 'Secant', 'Muller']
27+
28+
# make a list of safe functions
29+
safe_list = ['math', 'acos', 'asin', 'atan', 'atan2', 'ceil', 'cos', 'cosh',
30+
'degrees', 'e', 'exp', 'fabs', 'floor', 'fmod', 'frexp', 'hypot',
31+
'ldexp', 'log', 'log10', 'modf', 'pi', 'pow', 'radians', 'sin',
32+
'sinh', 'sqrt', 'tan', 'tanh', 'abs']
33+
safe_dict = dict([(k, globals().get(k, None)) for k in safe_list])
34+
35+
@staticmethod
36+
def __replace_carat__(func):
37+
return func.replace('^', '**')
38+
39+
def f(self, x, func=None):
40+
if func is None:
41+
func = self.user_func
42+
self.safe_dict['x'] = x
43+
return eval(func, {"__builtins__": None}, self.safe_dict)
44+
45+
@staticmethod
46+
def get_error(x, l):
47+
if x != 0.0:
48+
return abs((x - l) / x)
49+
else:
50+
return abs(x - l)
51+
52+
def add_to_lists(self, itr, error):
53+
self.itr_list.append(itr)
54+
self.error_list.append(error)
55+
56+
def __bracketing__(self, next):
57+
l, r = self.start
58+
error = float("inf")
59+
itr = 0
60+
prev_x = l
61+
while error > self.rel_err and itr < self.max_itr:
62+
x = next(l, r)
63+
error = NonLinearSolve.get_error(x, prev_x)
64+
if sgn(self.f(l)) == sgn(self.f(x)):
65+
l = x
66+
else:
67+
r = x
68+
# Add data into lists
69+
self.add_to_lists(itr, error)
70+
prev_x = x
71+
self.root = x
72+
itr += 1
73+
74+
def __iterative__(self, next):
75+
x_list = [i for i in self.start]
76+
error = float("inf")
77+
itr = 0
78+
while error > self.rel_err and itr < self.max_itr:
79+
x = next(x_list)
80+
error = NonLinearSolve.get_error(x, x_list[-1])
81+
# Add data into lists
82+
self.add_to_lists(itr, error)
83+
self.root = x
84+
x_list.append(x)
85+
itr += 1
86+
87+
def Bisection(self):
88+
self.__bracketing__(lambda x, y: (x + y) / 2)
89+
90+
def False_Position(self):
91+
def next(l, r):
92+
return l - (r - l) * self.f(l) / (self.f(r) - self.f(l))
93+
self.__bracketing__(next)
94+
95+
def Fixed_Point(self):
96+
phi_func = raw_input('Enter the function phi(x):')
97+
98+
def next(x_list):
99+
return self.f(x_list[-1], phi_func)
100+
self.__iterative__(next)
101+
102+
def Newton_Raphson(self):
103+
f_dash = raw_input("Enter the function f'(x):")
104+
105+
def next(x_list):
106+
x = x_list[-1]
107+
return x - self.f(x) / self.f(x, f_dash)
108+
self.__iterative__(next)
109+
110+
def Secant(self):
111+
def next(x_list):
112+
y, x = x_list[-1], x_list[-2]
113+
return y - (y - x) * (self.f(y)) / (self.f(y) - self.f(x))
114+
self.__iterative__(next)
115+
116+
def Muller(self):
117+
def next(x_list):
118+
x = x_list
119+
y = [self.f(x[i]) for i in (-3, -2, -1)]
120+
c = y[-1]
121+
a = (y[-1] - y[-3]) / (x[-1] - x[-3]) - \
122+
(y[-2] - y[-1]) / (x[-2] - x[-1])
123+
a /= (x[-3] - x[-2])
124+
b = (y[-1] - y[-2]) / (x[-1] - x[-2]) * (x[-3] - x[-1]) - \
125+
(x[-2] - x[-1]) * (y[-3] - y[-1]) / (x[-3] - x[-1])
126+
b /= (x[-3] - x[-2])
127+
det = sqrt(abs(b)**4 - 4 * a * c * (b.conjugate()**2))
128+
delx = (-2 * c * b.conjugate()) / (abs(b)**2 + det)
129+
return x[-1] + delx
130+
self.__iterative__(next)
131+
132+
@staticmethod
133+
def __format__(root):
134+
if root.imag == 0:
135+
return root.real
136+
return root
137+
138+
def get_root(self):
139+
return self.__format__(self.root)
140+
141+
def __plot_fx__(self, xvals):
142+
yvals = list(self.f(i) for i in xvals)
143+
plt.plot(xvals, yvals)
144+
plt.grid()
145+
plt.xlabel('x')
146+
plt.ylabel('f(x)')
147+
plt.title('f(x) vs x')
148+
plt.show()
149+
150+
def plot_fx(self):
151+
xvals = np.arange(self.root.real - 10, self.root.real + 10, 0.01)
152+
self.__plot_fx__(xvals)
153+
154+
def plot_error(self):
155+
plt.plot(self.itr_list, self.error_list)
156+
plt.grid()
157+
plt.xlabel('Iteration no.')
158+
plt.ylabel('Error')
159+
plt.title('Relative approximate error vs iteration number')
160+
plt.show()
161+
162+
163+
class Bairstow(NonLinearSolve):
164+
'''
165+
This method finds all roots of a polynimial
166+
This method supports polynomial only.
167+
'''
168+
methods = NonLinearSolve.methods + ['Bairstow']
169+
170+
def __init__(self, user_func, start, max_itr, rel_err, method):
171+
NonLinearSolve.__init__(self, user_func, start,
172+
max_itr, rel_err, method)
173+
self.roots_left = len(self.user_func) - 1
174+
self.poly = self.user_func
175+
176+
@staticmethod
177+
def get_error(delx, x):
178+
if x != 0.0:
179+
return delx / x
180+
else:
181+
return delx
182+
183+
def f(self, x):
184+
fx, term = 0, 1
185+
for coef in self.user_func[::-1]:
186+
fx += term * coef
187+
term *= x
188+
return fx
189+
190+
def find(self):
191+
if len(self.poly) < 3:
192+
if len(self.poly) == 2:
193+
self.root = [self.poly[1] / self.poly[0]]
194+
self.poly = []
195+
self.roots_left = 0
196+
return
197+
a = self.start
198+
error = float("inf")
199+
itr = 0
200+
while error > self.rel_err and itr < self.max_itr:
201+
d = [self.poly[0], self.poly[1] + a[1] * self.poly[0]]
202+
for i in range(len(self.poly) - 2):
203+
d.append(self.poly[i + 2] + a[1] * d[i + 1] + a[0] * d[i])
204+
del_d = [0, d[0], d[1] + a[1] * d[0]]
205+
for i in range(1, len(self.poly) - 2):
206+
del_d.append(d[i + 1] + a[1] * del_d[i + 1] + a[0] * del_d[i])
207+
matx = np.array([[del_d[-2], del_d[-1]],
208+
[del_d[-3], del_d[-2]]])
209+
vecb = np.array([-d[-1], -d[-2]])
210+
del_a = np.linalg.solve(matx, vecb)
211+
error = max(Bairstow.get_error(
212+
del_a[0], a[0]), Bairstow.get_error(del_a[1], a[1]))
213+
a[0] += del_a[0]
214+
a[1] += del_a[1]
215+
itr += 1
216+
else:
217+
self.poly = d[:-2]
218+
det = sqrt(a[1]**2 + 4 * a[0])
219+
self.root = [0.5 * (a[0] + det), 0.5 * (a[1] - det)]
220+
self.roots_left = len(self.poly) - 1
221+
222+
def update_start(self, start):
223+
self.start = start
224+
225+
def get_root(self):
226+
roots = [str(NonLinearSolve.__format__(r)) for r in self.root]
227+
return ', '.join(roots)
228+
229+
def get_roots_left(self):
230+
return self.roots_left
231+
232+
def plot_fx(self):
233+
xvals = np.arange(self.root[0].real - 10, self.root[0].real + 10, 0.01)
234+
self.__plot_fx__(xvals)
235+
236+
237+
def main():
238+
239+
print 'Press -1 to quit'
240+
for i, method in enumerate(Bairstow.methods):
241+
print 'Press ', i, 'for', method, 'Method'
242+
opt = int(raw_input())
243+
if opt == -1:
244+
return
245+
246+
if Bairstow.methods[opt] == 'Bairstow':
247+
user_func = map(float, raw_input(
248+
"Enter the polynomial coefficients:\n%s " %
249+
'[starting from nth degree to constant, separated by spaces]:'
250+
).split())
251+
else:
252+
user_func = raw_input('Enter the function f(x): ')
253+
user_func = NonLinearSolve.__replace_carat__(user_func)
254+
start = map(float, raw_input(
255+
'Enter starting values(comma separated): ').split(','))
256+
max_itr = int(raw_input('Enter the maximum no. of iterations: '))
257+
rel_err = float(
258+
raw_input('Enter the required relative approximate error %: ')) / 100.0
259+
260+
if Bairstow.methods[opt] == 'Bairstow':
261+
solver = Bairstow(user_func, start, max_itr, rel_err,
262+
Bairstow.methods[opt])
263+
while True:
264+
solver.find()
265+
print 'Roots are: ', solver.get_root()
266+
if solver.get_roots_left() > 0:
267+
print 'Number of roots left: ', solver.get_roots_left()
268+
if solver.get_roots_left() > 1:
269+
start = map(float, raw_input(
270+
'Enter starting values(comma separated): ').split(','))
271+
solver.update_start(start)
272+
else:
273+
break
274+
if raw_input('Plot f(x) vs x?[Press 1 for yes]: ') is '1':
275+
solver.plot_fx()
276+
else:
277+
solver = NonLinearSolve(user_func, start, max_itr, rel_err,
278+
NonLinearSolve.methods[opt])
279+
getattr(solver, NonLinearSolve.methods[opt])()
280+
281+
print 'Root of f(x)=0 is: ', solver.get_root()
282+
if raw_input('Plot f(x) vs x?[Press 1 for yes]: ') is '1':
283+
solver.plot_fx()
284+
if raw_input("Plot relative approximate error vs iteration number?%s: " %
285+
'[Press 1 for yes]') is '1':
286+
solver.plot_error()
287+
288+
289+
if __name__ == '__main__':
290+
main()

0 commit comments

Comments
 (0)