Skip to content
2 changes: 1 addition & 1 deletion src/mplfinance/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version_info = (0, 12, 9, 'beta', 2)
version_info = (0, 12, 9, 'beta', 3)

_specifier_ = {'alpha': 'a','beta': 'b','candidate': 'rc','final': ''}

Expand Down
67 changes: 63 additions & 4 deletions src/mplfinance/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,17 @@ def _valid_plot_kwargs():
'mav' : { 'Default' : None,
'Description' : 'Moving Average window size(s); (int or tuple of ints)',
'Validator' : _mav_validator },

'ema' : { 'Default' : None,
'Description' : 'Exponential Moving Average window size(s); (int or tuple of ints)',
'Validator' : _mav_validator },

'mavcolors' : { 'Default' : None,
'Description' : 'color cycle for moving averages (list or tuple of colors)'+
'(overrides mpf style mavcolors).',
'Validator' : lambda value: isinstance(value,(list,tuple)) and
all([mcolors.is_color_like(v) for v in value]) },

'renko_params' : { 'Default' : dict(),
'Description' : 'dict of renko parameters; call `mpf.kwarg_help("renko_params")`',
'Validator' : lambda value: isinstance(value,dict) },
Expand Down Expand Up @@ -450,6 +460,13 @@ def plot( data, **kwargs ):
else:
raise TypeError('style should be a `dict`; why is it not?')

if config['mavcolors'] is not None:
config['_ma_color_cycle'] = cycle(config['mavcolors'])
elif style['mavcolors'] is not None:
config['_ma_color_cycle'] = cycle(style['mavcolors'])
else:
config['_ma_color_cycle'] = None

if not external_axes_mode:
fig = plt.figure()
_adjust_figsize(fig,config)
Expand Down Expand Up @@ -528,8 +545,10 @@ def plot( data, **kwargs ):

if ptype in VALID_PMOVE_TYPES:
mavprices = _plot_mav(axA1,config,xdates,pmove_avgvals)
emaprices = _plot_ema(axA1, config, xdates, pmove_avgvals)
else:
mavprices = _plot_mav(axA1,config,xdates,closes)
emaprices = _plot_ema(axA1, config, xdates, closes)

avg_dist_between_points = (xdates[-1] - xdates[0]) / float(len(xdates))
if not config['tight_layout']:
Expand Down Expand Up @@ -595,6 +614,13 @@ def plot( data, **kwargs ):
else:
for jj in range(0,len(mav)):
retdict['mav' + str(mav[jj])] = mavprices[jj]
if config['ema'] is not None:
ema = config['ema']
if len(ema) != len(emaprices):
warnings.warn('len(ema)='+str(len(ema))+' BUT len(emaprices)='+str(len(emaprices)))
else:
for jj in range(0, len(ema)):
retdict['ema' + str(ema[jj])] = emaprices[jj]
retdict['minx'] = minx
retdict['maxx'] = maxx
retdict['miny'] = miny
Expand Down Expand Up @@ -1129,10 +1155,7 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None):
if len(mavgs) > 7:
mavgs = mavgs[0:7] # take at most 7

if style['mavcolors'] is not None:
mavc = cycle(style['mavcolors'])
else:
mavc = None
mavc = config['_ma_color_cycle']

for idx,mav in enumerate(mavgs):
mean = pd.Series(prices).rolling(mav).mean()
Expand All @@ -1147,6 +1170,42 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None):
mavp_list.append(mavprices)
return mavp_list


def _plot_ema(ax,config,xdates,prices,apmav=None,apwidth=None):
'''ema: exponential moving average'''
style = config['style']
if apmav is not None:
mavgs = apmav
else:
mavgs = config['ema']
mavp_list = []
if mavgs is not None:
shift = None
if isinstance(mavgs,dict):
shift = mavgs['shift']
mavgs = mavgs['period']
if isinstance(mavgs,int):
mavgs = mavgs, # convert to tuple
if len(mavgs) > 7:
mavgs = mavgs[0:7] # take at most 7

mavc = config['_ma_color_cycle']

for idx,mav in enumerate(mavgs):
# mean = pd.Series(prices).rolling(mav).mean()
mean = pd.Series(prices).ewm(span=mav,adjust=False).mean()
if shift is not None:
mean = mean.shift(periods=shift[idx])
emaprices = mean.values
lw = config['_width_config']['line_width']
if mavc:
ax.plot(xdates, emaprices, linewidth=lw, color=next(mavc))
else:
ax.plot(xdates, emaprices, linewidth=lw)
mavp_list.append(emaprices)
return mavp_list


def _auto_secondary_y( panels, panid, ylo, yhi ):
# If mag(nitude) for this panel is not yet set, then set it
# here, as this is the first ydata to be plotted on this panel:
Expand Down
Binary file added tests/reference_images/ema01.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/reference_images/ema02.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/reference_images/ema03.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
124 changes: 124 additions & 0 deletions tests/test_ema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import os
import os.path
import glob
import mplfinance as mpf
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.testing.compare import compare_images

print('mpf.__version__ =',mpf.__version__) # for the record
print('mpf.__file__ =',mpf.__file__) # for the record
print("plt.rcParams['backend'] =",plt.rcParams['backend']) # for the record

base='ema'
tdir = os.path.join('tests','test_images')
refd = os.path.join('tests','reference_images')

globpattern = os.path.join(tdir,base+'*.png')
oldtestfiles = glob.glob(globpattern)
for fn in oldtestfiles:
try:
os.remove(fn)
except:
print('Error removing file "'+fn+'"')

IMGCOMP_TOLERANCE = 10.0 # this works fine for linux
# IMGCOMP_TOLERANCE = 11.0 # required for a windows pass. (really 10.25 may do it).

_df = pd.DataFrame()
def get_ema_data():
global _df
if len(_df) == 0:
_df = pd.read_csv('./examples/data/yahoofinance-GOOG-20040819-20180120.csv',
index_col='Date',parse_dates=True)
return _df


def create_ema_image(tname):

df = get_ema_data()
df = df[-50:] # show last 50 data points only

ema25 = df['Close'].ewm(span=25.0, adjust=False).mean()
mav25 = df['Close'].rolling(window=25).mean()

ap = [
mpf.make_addplot(df, panel=1, type='ohlc', color='c',
ylabel='mpf mav', mav=25, secondary_y=False),
mpf.make_addplot(ema25, panel=2, type='line', width=2, color='c',
ylabel='calculated', secondary_y=False),
mpf.make_addplot(mav25, panel=2, type='line', width=2, color='blue',
ylabel='calculated', secondary_y=False)
]

# plot and save in `tname` path
mpf.plot(df, ylabel="mpf ema", type='ohlc',
ema=25, addplot=ap, panel_ratios=(1, 1), savefig=tname
)


def test_ema01():

fname = base+'01.png'
tname = os.path.join(tdir,fname)
rname = os.path.join(refd,fname)

create_ema_image(tname)

tsize = os.path.getsize(tname)
print(glob.glob(tname),'[',tsize,'bytes',']')

rsize = os.path.getsize(rname)
print(glob.glob(rname),'[',rsize,'bytes',']')

result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
if result is not None:
print('result=',result)
assert result is None

def test_ema02():
fname = base+'02.png'
tname = os.path.join(tdir,fname)
rname = os.path.join(refd,fname)

df = get_ema_data()
df = df[-125:-35]

mpf.plot(df, type='candle', ema=(5,15,25), mav=(5,15,25), savefig=tname)

tsize = os.path.getsize(tname)
print(glob.glob(tname),'[',tsize,'bytes',']')

rsize = os.path.getsize(rname)
print(glob.glob(rname),'[',rsize,'bytes',']')

result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
if result is not None:
print('result=',result)
assert result is None

def test_ema03():
fname = base+'03.png'
tname = os.path.join(tdir,fname)
rname = os.path.join(refd,fname)

df = get_ema_data()
df = df[-125:-35]

mac = ['red','orange','yellow','green','blue','purple']

mpf.plot(df, type='candle', ema=(5,10,15,25), mav=(5,15,25),
mavcolors=mac, savefig=tname)


tsize = os.path.getsize(tname)
print(glob.glob(tname),'[',tsize,'bytes',']')

rsize = os.path.getsize(rname)
print(glob.glob(rname),'[',rsize,'bytes',']')

result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
if result is not None:
print('result=',result)
assert result is None