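DataLoader
- When the built-in Dataset classes do not meet your needs, define your own by subclassing torch.utils.data.Dataset and overriding __init__, __getitem__, and __len__. If these methods are missing, DataLoader fails when it loads the custom Dataset, typically with a NotImplementedError.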
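A minimal sketch of the subclassing pattern described above. The class name SquaresDataset and its contents are hypothetical and chosen only for illustration; the three overridden methods are the ones the note lists.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i**2)."""

    def __init__(self, n):
        # Store whatever is needed to produce samples later.
        self.n = n

    def __len__(self):
        # DataLoader uses this to know how many samples exist.
        return self.n

    def __getitem__(self, idx):
        # Return one sample; DataLoader stacks samples into batches.
        x = torch.tensor([float(idx)])
        y = torch.tensor([float(idx) ** 2])
        return x, y


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# Collect the batch shapes to see how samples were grouped.
batches = [(x.shape, y.shape) for x, y in loader]
```

With 10 samples and a batch size of 4, the loader yields two full batches and one partial batch.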
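NumPy broadcasting rules:
- All input arrays are aligned to the array with the longest shape; any shape that is shorter is padded with 1s at the front.
- The output shape is the maximum, axis by axis, of the input shapes.
- An input array can take part in the computation only if, on every axis, its length either equals the length of the corresponding output axis or is 1; otherwise an error is raised.
- When an input array has length 1 on some axis, its single set of values on that axis is reused for every position along that axis.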
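The rules above can be checked with a short sketch. The array shapes and variable names are arbitrary choices for illustration.

```python
import numpy as np

# Rules 1 and 2: shape (4, 1) is padded at the front to (1, 4, 1),
# and the output takes the per-axis maximum of (3, 1, 5) and (1, 4, 1).
a = np.ones((3, 1, 5))
b = np.arange(4).reshape(4, 1)
c = a + b
result_shape = c.shape

# Rule 4: a length-1 axis reuses its single set of values.
row = np.array([[10, 20, 30]])  # shape (1, 3)
col = np.array([[1], [2]])      # shape (2, 1)
total = row + col               # shape (2, 3)

# Rule 3: lengths 3 and 2 on the first axis are neither equal nor 1,
# so broadcasting fails with a ValueError.
try:
    np.ones((3, 4)) + np.ones((2, 4))
    broadcast_failed = False
except ValueError:
    broadcast_failed = True
```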
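CUDA extensions in PyTorch:
Extensions are built with create_extension from torch.utils.ffi (the legacy cffi-based API, removed as of PyTorch 1.0 in favor of torch.utils.cpp_extension):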
def create_extension(name, headers, sources, verbose=True, with_cuda=False,
                     package=False, relative_to='.', **kwargs):
    """Creates and configures a cffi.FFI object, that builds PyTorch extension.

    Arguments:
        name (str): package name. Can be a nested module e.g. ``.ext.my_lib``.
        headers (str or List[str]): list of headers, that contain only exported
            functions
        sources (List[str]): list of sources to compile.
        verbose (bool, optional): if set to ``False``, no output will be printed
            (default: True).
        with_cuda (bool, optional): set to ``True`` to compile with CUDA headers
            (default: False)
        package (bool, optional): set to ``True`` to build in package mode (for modules
            meant to be installed as pip packages) (default: False).
        relative_to (str, optional): path of the build file. Required when
            ``package is True``. It's best to use ``__file__`` for this argument.
        kwargs: additional arguments that are passed to ffi to declare the
            extension. See `Extension API reference`_ for details.

    .. _`Extension API reference`: https://docs.python.org/3/distutils/apiref.html#distutils.core.Extension
    """
    base_path = os.path.abspath(os.path.dirname(relative_to))
    name_suffix, target_dir = _create_module_dir(base_path, name)
    if not package:
        cffi_wrapper_name = '_' + name_suffix
    else:
        cffi_wrapper_name = (name.rpartition('.')[0] +
                             '.{0}._{0}'.format(name_suffix))

    wrapper_source, include_dirs = _setup_wrapper(with_cuda)
    include_dirs.extend(kwargs.pop('include_dirs', []))

    if os.sys.platform == 'win32':
        library_dirs = glob.glob(os.getenv('CUDA_PATH', '') + '/lib/x64')
        library_dirs += glob.glob(os.getenv('NVTOOLSEXT_PATH', '') + '/lib/x64')

        here = os.path.abspath(os.path.dirname(__file__))
        lib_dir = os.path.join(here, '..', '..', 'lib')

        library_dirs.append(os.path.join(lib_dir))
    else:
        library_dirs = []
    library_dirs.extend(kwargs.pop('library_dirs', []))

    if isinstance(headers, str):
        headers = [headers]
    all_headers_source = ''
    for header in headers:
        with open(os.path.join(base_path, header), 'r') as f:
            all_headers_source += f.read() + '\n\n'

    ffi = cffi.FFI()
    sources = [os.path.join(base_path, src) for src in sources]
    # NB: TH headers are C99 now
    kwargs['extra_compile_args'] = ['-std=c99'] + kwargs.get('extra_compile_args', [])
    ffi.set_source(cffi_wrapper_name, wrapper_source + all_headers_source,
                   sources=sources,
                   include_dirs=include_dirs,
                   library_dirs=library_dirs, **kwargs)
    ffi.cdef(_typedefs + all_headers_source)

    _make_python_wrapper(name_suffix, '_' + name_suffix, target_dir)

    def build():
        _build_extension(ffi, cffi_wrapper_name, target_dir, verbose)
    ffi.build = build
    return ffi
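Two steps of the function above can be isolated as a rough sketch: how the cffi wrapper module name depends on the package flag, and how the headers are concatenated into one source string. The helper names wrapper_name and concat_headers are hypothetical. The sketch also assumes that name_suffix is the last dotted component of name; in the source it is returned by _create_module_dir, which is not shown here.

```python
import os
import tempfile


def wrapper_name(name, package=False):
    # Assumption: name_suffix is the last dotted component of `name`.
    # In the source it comes from _create_module_dir (not shown).
    name_suffix = name.rpartition('.')[2]
    if not package:
        return '_' + name_suffix
    # Package mode nests the wrapper inside the extension's own module.
    return name.rpartition('.')[0] + '.{0}._{0}'.format(name_suffix)


def concat_headers(base_path, headers):
    # A single header path is accepted as a plain string.
    if isinstance(headers, str):
        headers = [headers]
    source = ''
    for header in headers:
        with open(os.path.join(base_path, header), 'r') as f:
            source += f.read() + '\n\n'
    return source


local_name = wrapper_name('pkg.ext.my_lib')
package_name = wrapper_name('pkg.ext.my_lib', package=True)

# Write a throwaway header so the concatenation step can be exercised.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, 'my_lib.h'), 'w') as f:
        f.write('int my_add(int a, int b);')
    header_text = concat_headers(tmp, 'my_lib.h')
```

The concatenated header text is what the original passes both to ffi.set_source (after the wrapper source) and to ffi.cdef (after _typedefs).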