pandafeed.py、blaze.py、influxDB、btcsv.py、mt4csv.py 五个对接数据文件的源代码注释,后续基于对 backtrader 源代码的研究会重新梳理下 backtrader 的结构,让大家更好的理解 backtrader 和使用,目前源代码注释和阅读都是一个比较枯燥的过程,不感兴趣可以忽略,这部分内容的忽略不怎么影响大家使用 backtrader.
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from backtrader.utils.py3 import filter, string_types, integer_types
from backtrader import date2num
import backtrader.feed as feed
# backtrader 通过 pandas 加载数据
class PandasDirectData(feed.DataBase):
'''
Uses a Pandas DataFrame as the feed source, iterating directly over the
tuples returned by "itertuples".
This means that all parameters related to lines must have numeric
values as indices into the tuples
Note:
- The ``dataname`` parameter is a Pandas DataFrame
- A negative value in any of the parameters for the Data lines
indicates it's not present in the DataFrame
it is
'''
# 参数
params = (
('datetime', 0),
('open', 1),
('high', 2),
('low', 3),
('close', 4),
('volume', 5),
('openinterest', 6),
)
# 列名
datafields = [
'datetime', 'open', 'high', 'low', 'close', 'volume', 'openinterest'
]
# 开始,把 dataframe 数据转化成可以迭代的元组,每一行一个元组
def start(self):
super(PandasDirectData, self).start()
# reset the iterator on each start
self._rows = self.p.dataname.itertuples()
def _load(self):
# 尝试获取下一个 row,如果获取不到报错,就返回 False
try:
row = next(self._rows)
except StopIteration:
return False
# Set the standard datafields - except for datetime
# 对于除了 datetime 之外的列,把数据根据列名添加到 line 中
for datafield in self.getlinealiases():
if datafield == 'datetime':
continue
# get the column index
colidx = getattr(self.params, datafield)
if colidx < 0:
# column not present -- skip
continue
# get the line to be set
line = getattr(self.lines, datafield)
# indexing for pandas: 1st is colum, then row
line[0] = row[colidx]
# datetime
# 对于 datetime,获取 datetime 所在列的 index,然后获取时间
colidx = getattr(self.params, 'datetime')
tstamp = row[colidx]
# convert to float via datetime and store it
# 把时间戳转化成具体的 datetime 格式,然后转化成数字
dt = tstamp.to_pydatetime()
dtnum = date2num(dt)
# get the line to be set
# 获取 datetime 的 line,然后保存这个数字
line = getattr(self.lines, 'datetime')
line[0] = dtnum
# Done ... return
return True
class PandasData(feed.DataBase):
'''
Uses a Pandas DataFrame as the feed source, using indices into column
names (which can be "numeric")
This means that all parameters related to lines must have numeric
values as indices into the tuples
Params:
- ``nocase`` (default *True*) case insensitive match of column names
Note:
- The ``dataname`` parameter is a Pandas DataFrame
- Values possible for datetime
- None: the index contains the datetime
- -1: no index, autodetect column
- >= 0 or string: specific colum identifier
- For other lines parameters
- None: column not present
- -1: autodetect
- >= 0 or string: specific colum identifier
'''
# 参数及其含义
params = (
('nocase', True),
# Possible values for datetime (must always be present)
# None : datetime is the "index" in the Pandas Dataframe
# -1 : autodetect position or case-wise equal name
# >= 0 : numeric index to the colum in the pandas dataframe
# string : column name (as index) in the pandas dataframe
('datetime', None),
# Possible values below:
# None : column not present
# -1 : autodetect position or case-wise equal name
# >= 0 : numeric index to the colum in the pandas dataframe
# string : column name (as index) in the pandas dataframe
('open', -1),
('high', -1),
('low', -1),
('close', -1),
('volume', -1),
('openinterest', -1),
)
# 数据的列名
datafields = [
'datetime', 'open', 'high', 'low', 'close', 'volume', 'openinterest'
]
# 类初始化
def __init__(self):
super(PandasData, self).__init__()
# these "colnames" can be strings or numeric types
# 列的名字,列表格式
colnames = list(self.p.dataname.columns.values)
# 如果 datetime 在 index 中
if self.p.datetime is None:
# datetime is expected as index col and hence not returned
pass
# try to autodetect if all columns are numeric
# 尝试判断 cstrings 是不是字符串,把不是字符串的过滤掉
cstrings = filter(lambda x: isinstance(x, string_types), colnames)
# 如果有一个是字符串,那么 colsnumeric 就是 False,只有全部是数字的情况下,才会返回 True
colsnumeric = not len(list(cstrings))
# Where each datafield find its value
# 定义一个字典
self._colmapping = dict()
# Build the column mappings to internal fields in advance
# 遍历每个列
for datafield in self.getlinealiases():
# 列所在的 index
defmapping = getattr(self.params, datafield)
# 如果列的 index 是数字并且小于 0,需要自动探测
if isinstance(defmapping, integer_types) and defmapping < 0:
# autodetection requested
for colname in colnames:
# 如果列名是字符串
if isinstance(colname, string_types):
# 如果没有大小写的区别,对比小写状态是否相等,如果相等就代表找到了,否则就直接对比是否相等
if self.p.nocase:
found = datafield.lower() == colname.lower()
else:
found = datafield == colname
# 如果找到了,那么就把 datafield 和 colname 进行一一对应,然后退出这个循环,继续 datafield
if found:
self._colmapping[datafield] = colname
break
# 如果找了一遍 df 的列没有找到,就设置成 None
if datafield not in self._colmapping:
# autodetection requested and not found
self._colmapping[datafield] = None
continue
# 如果 datafield 用户自己进行了定义,那么就直接使用用户定义的
else:
# all other cases -- used given index
self._colmapping[datafield] = defmapping
# 开始处理数据
def start(self):
super(PandasData, self).start()
# 开始之前,先重新设置 _idx
# reset the length with each start
self._idx = -1
# Transform names (valid for .ix) into indices (good for .iloc)
# 如果大小写不敏感,就把数据的列名转化成小写,如果敏感,保持原样
if self.p.nocase:
colnames = [x.lower() for x in self.p.dataname.columns.values]
else:
colnames = [x for x in self.p.dataname.columns.values]
# 对于 datafield 和列名进行迭代
for k, v in self._colmapping.items():
# 如果列名是 None 的话,代表这个列很可能是时间
if v is None:
continue # special marker for datetime
# 如果列名是字符串的话,如果大小写不敏感,就先转化成小写,如果不敏感,忽略,然后根据列名得到列所在的 index
if isinstance(v, string_types):
# 这下面的一些代码似乎有些无效,感觉可以忽略,直接使用 self._colmapping[k] = colnames.index(v)替代就好了
try:
if self.p.nocase:
v = colnames.index(v.lower())
else:
v = colnames.index(v)
except ValueError as e:
defmap = getattr(self.params, k)
if isinstance(defmap, integer_types) and defmap < 0:
v = None
else:
raise e # let user now something failed
# 如果不是字符串,用户自定义了具体的整数,直接使用用户自定义的
self._colmapping[k] = v
def _load(self):
# 每次 load 一行,_idx 每次加 1
self._idx += 1
# 如果 _idx 已经大于了数据的长度,返回 False
if self._idx >= len(self.p.dataname):
# exhausted all rows
return False
# Set the standard datafields
# 循环 datafield
for datafield in self.getlinealiases():
# 如果是时间,继续上面的循环
if datafield == 'datetime':
continue
colindex = self._colmapping[datafield]
# 如果列的 index 是 None,继续上面的循环
if colindex is None:
# datafield signaled as missing in the stream: skip it
continue
# get the line to be set
line = getattr(self.lines, datafield)
# indexing for pandas: 1st is colum, then row
# 使用 iloc 读取 dataframe 的数据,感觉效率一般
line[0] = self.p.dataname.iloc[self._idx, colindex]
# datetime conversion
coldtime = self._colmapping['datetime']
# 如果 datetime 所在的列是 None 的话,直接通过 index 获取时间,如果不是 None 的话,通过 iloc 获取时间数据
if coldtime is None:
# standard index in the datetime
tstamp = self.p.dataname.index[self._idx]
else:
# it's in a different column ... use standard column index
tstamp = self.p.dataname.iloc[self._idx, coldtime]
# convert to float via datetime and store it
# 转换时间数据并保存
dt = tstamp.to_pydatetime()
dtnum = date2num(dt)
self.lines.datetime[0] = dtnum
# Done ... return
return True from __future__ import (absolute_import, division, print_function,
unicode_literals)
from backtrader import date2num
import backtrader.feed as feed
# 这个类是 backtrader 对接 Blaze 数据的类
# blaze 介绍可以看这个:https://blaze.readthedocs.io/en/latest/index.html
class BlazeData(feed.DataBase):
'''
Support for `Blaze <blaze.pydata.org>`_ ``Data`` objects.
Only numeric indices to columns are supported.
Note:
- The ``dataname`` parameter is a blaze ``Data`` object
- A negative value in any of the parameters for the Data lines
indicates it's not present in the DataFrame
it is
'''
# 参数
params = (
# datetime must be present
('datetime', 0),
# pass -1 for any of the following to indicate absence
('open', 1),
('high', 2),
('low', 3),
('close', 4),
('volume', 5),
('openinterest', 6),
)
# 列名称
datafields = [
'datetime', 'open', 'high', 'low', 'close', 'volume', 'openinterest'
]
# 开始,直接把数据文件使用 iter 迭代,接下来 _load 的时候每次读取一行
def start(self):
super(BlazeData, self).start()
# reset the iterator on each start
self._rows = iter(self.p.dataname)
# load 数据
def _load(self):
# 尝试获取下一行的数据,如果不存在,那么报错,返回 False,代表数据已经 load 完毕
try:
row = next(self._rows)
except StopIteration:
return False
# Set the standard datafields - except for datetime
# 设置除了时间之外的其他数据,这个跟 CSV 操作差不多
for datafield in self.datafields[1:]:
# get the column index
colidx = getattr(self.params, datafield)
if colidx < 0:
# column not present -- skip
continue
# get the line to be set
line = getattr(self.lines, datafield)
line[0] = row[colidx]
# datetime - assumed blaze always serves a native datetime.datetime
# 处理时间部分,这部分操作相比于 CSV 部分的操作简单了很多,效率上应该也比较高,理论上会比 CSV 快一些
# 获取数据第一列的 index
colidx = getattr(self.params, self.datafields[0])
# 获取时间数据
dt = row[colidx]
# 把时间转化为数字
dtnum = date2num(dt)
# get the line to be set
# 获取该列的 line,然后添加数据
line = getattr(self.lines, self.datafields[0])
line[0] = dtnum
# Done ... return
return True from __future__ import (absolute_import, division, print_function,
unicode_literals)
import backtrader as bt
import backtrader.feed as feed
from ..utils import date2num
import datetime as dt
# 时间周期的对应
TIMEFRAMES = dict(
(
(bt.TimeFrame.Seconds, 's'),
(bt.TimeFrame.Minutes, 'm'),
(bt.TimeFrame.Days, 'd'),
(bt.TimeFrame.Weeks, 'w'),
(bt.TimeFrame.Months, 'm'),
(bt.TimeFrame.Years, 'y'),
)
)
# backtrader 从 influxDB 获取数据
class InfluxDB(feed.DataBase):
# 导入包
frompackages = (
('influxdb', [('InfluxDBClient', 'idbclient')]),
('influxdb.exceptions', 'InfluxDBClientError')
)
# 参数
params = (
('host', '127.0.0.1'),
('port', '8086'),
('username', None),
('password', None),
('database', None),
('timeframe', bt.TimeFrame.Days),
('startdate', None),
('high', 'high_p'),
('low', 'low_p'),
('open', 'open_p'),
('close', 'close_p'),
('volume', 'volume'),
('ointerest', 'oi'),
)
# 开始
def start(self):
super(InfluxDB, self).start()
# 尝试连接数据库
try:
self.ndb = idbclient(self.p.host, self.p.port, self.p.username,
self.p.password, self.p.database)
except InfluxDBClientError as err:
print('Failed to establish connection to InfluxDB: %s' % err)
# 具体的时间周期
tf = '{multiple}{timeframe}'.format(
multiple=(self.p.compression if self.p.compression else 1),
timeframe=TIMEFRAMES.get(self.p.timeframe, 'd'))
# 开始时间
if not self.p.startdate:
st = '<= now()'
else:
st = '>= \'%s\'' % self.p.startdate
# The query could already consider parameters like fromdate and todate
# to have the database skip them and not the internal code
# 具体的数据库获取数据需要设置的命令
qstr = ('SELECT mean("{open_f}") AS "open", mean("{high_f}") AS "high", '
'mean("{low_f}") AS "low", mean("{close_f}") AS "close", '
'mean("{vol_f}") AS "volume", mean("{oi_f}") AS "openinterest" '
'FROM "{dataname}" '
'WHERE time {begin} '
'GROUP BY time({timeframe}) fill(none)').format(
open_f=self.p.open, high_f=self.p.high,
low_f=self.p.low, close_f=self.p.close,
vol_f=self.p.volume, oi_f=self.p.ointerest,
timeframe=tf, begin=st, dataname=self.p.dataname)
# 获取数据
try:
dbars = list(self.ndb.query(qstr).get_points())
except InfluxDBClientError as err:
print('InfluxDB query failed: %s' % err)
# 迭代数据
self.biter = iter(dbars)
def _load(self):
# 尝试获取下个 bar 的数据,然后添加到 line 中
try:
bar = next(self.biter)
except StopIteration:
return False
self.l.datetime[0] = date2num(dt.datetime.strptime(bar['time'],
'%Y-%m-%dT%H:%M:%SZ'))
self.l.open[0] = bar['open']
self.l.high[0] = bar['high']
self.l.low[0] = bar['low']
self.l.close[0] = bar['close']
self.l.volume[0] = bar['volume']
return True from __future__ import (absolute_import, division, print_function,
unicode_literals)
from datetime import date, datetime, time
from .. import feed
from ..utils import date2num
# 解析一个自定义的 csv data,主要用于测试使用。
class BacktraderCSVData(feed.CSVDataBase):
'''
Parses a self-defined CSV Data used for testing.
Specific parameters:
- ``dataname``: The filename to parse or a file-like object
'''
# 对每行数据进行处理
def _loadline(self, linetokens):
# 把每行数据进行迭代
itoken = iter(linetokens)
# 时间处理
dttxt = next(itoken) # Format is YYYY-MM-DD - skip char 4 and 7
dt = date(int(dttxt[0:4]), int(dttxt[5:7]), int(dttxt[8:10]))
# 如果列有 8 个,代表存在时间,第二列是时间,对时间进行处理,如果不是 8 列,代表没有时间,时间是用 sessionend
if len(linetokens) == 8:
tmtxt = next(itoken) # Format if present HH:MM:SS, skip 3 and 6
tm = time(int(tmtxt[0:2]), int(tmtxt[3:5]), int(tmtxt[6:8]))
else:
tm = self.p.sessionend # end of the session parameter
# 分别设置各个 line
self.lines.datetime[0] = date2num(datetime.combine(dt, tm))
self.lines.open[0] = float(next(itoken))
self.lines.high[0] = float(next(itoken))
self.lines.low[0] = float(next(itoken))
self.lines.close[0] = float(next(itoken))
self.lines.volume[0] = float(next(itoken))
self.lines.openinterest[0] = float(next(itoken))
return True
class BacktraderCSV(feed.CSVFeedBase):
# 类,DataCls 设置了值为 BacktraderCSVData 类
DataCls = BacktraderCSVData from __future__ import (absolute_import, division, print_function,
unicode_literals)
from . import GenericCSVData
class MT4CSVData(GenericCSVData):
'''
Parses a `Metatrader4 <https://www.metaquotes.net/en/metatrader4>`_ History
center CSV exported file.
Specific parameters (or specific meaning):
- ``dataname``: The filename to parse or a file-like object
- Uses GenericCSVData and simply modifies the params
'''
# MT4 数据的这个类,继承了 CSV 类,仅仅只是修改了参数
params = (
('dtformat', '%Y.%m.%d'),
('tmformat', '%H:%M'),
('datetime', 0),
('time', 1),
('open', 2),
('high', 3),
('low', 4),
('close', 5),
('volume', 6),
('openinterest', -1),
)