==> python-opt-einsum-fx-0.1.4/.github/workflows/release.yml <==
name: Upload Python Package
on:
  release:
    types: [created]

jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: 3.8
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install setuptools wheel twine
      - name: Build and publish
        env:
          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          python setup.py sdist bdist_wheel
          twine upload dist/*
==> python-opt-einsum-fx-0.1.4/.github/workflows/tests.yml <==
name: Check Syntax and Run Tests
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.6, 3.9]
        torch-version: [1.8.0, 1.10.0]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install flake8
        run: |
          pip install flake8
      - name: Lint with flake8
        run: |
          flake8 . --count --show-source --statistics
      - name: Install dependencies
        env:
          TORCH: "${{ matrix.torch-version }}"
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          python -m pip install --upgrade pip
          pip install wheel
          pip install torch==${TORCH} torchvision torchaudio -f https://download.pytorch.org/whl/cpu/torch_stable.html
          pip install .
      - name: Install pytest
        run: |
          pip install pytest
      - name: Test with pytest
        run: |
          pytest --doctest-modules --ignore=docs/ .
==> python-opt-einsum-fx-0.1.4/.gitignore <==
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.vscode

==> python-opt-einsum-fx-0.1.4/.readthedocs.yaml <==
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Build documentation in the docs/ directory with Sphinx
sphinx:
  configuration: docs/conf.py

# Optionally set the version of Python and requirements required to build your docs
python:
  version: 3.7
  install:
    - method: pip
      path: .

==> python-opt-einsum-fx-0.1.4/CHANGELOG.md <==
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
Most recent change on the bottom.
## [Unreleased]
## 0.1.4 - 2021-11-07
### Added
- `opt_einsum_fx.__version__`
- Partially symbolic shape propagation for efficient einsum optimization (#15)
## 0.1.3 - 2021-10-29
### Added
- PyTorch 1.10 compatibility
### Fixed
- Added `packaging` to dependency list
## 0.1.2 - 2021-06-28
### Added
- PyTorch 1.9 compatibility
## 0.1.1 - 2021-05-27
### Added
- Docs
- PyPI package
### Fixed
- `jitable` no longer makes some FX nodes' `.args` into lists (which are technically not allowed) instead of tuples
## [0.1.0] - 2021-05-17

==> python-opt-einsum-fx-0.1.4/LICENSE <==
MIT License
Copyright (c) 2021 Alby M.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
python-opt-einsum-fx-0.1.4/README.md 0000664 0000000 0000000 00000005141 15046477414 0017042 0 ustar 00root root 0000000 0000000 # opt_einsum_fx
[](https://opt-einsum-fx.readthedocs.io/en/latest/?badge=latest)
Optimizing einsums and functions involving them using [`opt_einsum`](https://optimized-einsum.readthedocs.io/en/stable/) and PyTorch [FX](https://pytorch.org/docs/stable/fx.html) compute graphs.
Issues, questions, PRs, and any thoughts about further optimizing these kinds of operations are welcome!
For more information please see [the docs](https://opt-einsum-fx.readthedocs.io/en/stable/).
## Installation
### PyPI
The latest release can be installed from PyPI:
```bash
$ pip install opt_einsum_fx
```
### Source
To get the latest code, run:
```bash
$ git clone https://github.com/Linux-cpp-lisp/opt_einsum_fx.git
```
and install it by running
```bash
$ cd opt_einsum_fx/
$ pip install .
```
You can run the tests with
```bash
$ pytest tests/
```
## Minimal example
```python
import torch
import torch.fx
import opt_einsum_fx

def einmatvecmul(a, b, vec):
    """Batched matrix-matrix-vector product using einsum"""
    return torch.einsum("zij,zjk,zk->zi", a, b, vec)

graph_mod = torch.fx.symbolic_trace(einmatvecmul)
print("Original code:\n", graph_mod.code)
graph_opt = opt_einsum_fx.optimize_einsums_full(
    model=graph_mod,
    example_inputs=(
        torch.randn(7, 4, 5),
        torch.randn(7, 5, 3),
        torch.randn(7, 3)
    )
)
print("Optimized code:\n", graph_opt.code)
```
outputs
```
Original code:
import torch
def forward(self, a, b, vec):
    einsum_1 = torch.functional.einsum('zij,zjk,zk->zi', a, b, vec); a = b = vec = None
    return einsum_1

Optimized code:
import torch
def forward(self, a, b, vec):
    einsum_1 = torch.functional.einsum('cb,cab->ca', vec, b); vec = b = None
    einsum_2 = torch.functional.einsum('cb,cab->ca', einsum_1, a); einsum_1 = a = None
    return einsum_2
```
We can measure the performance improvement (this is on a CPU):
```python
from torch.utils.benchmark import Timer
batch = 1000
a, b, vec = torch.randn(batch, 4, 5), torch.randn(batch, 5, 8), torch.randn(batch, 8)
g = {"f": graph_mod, "a": a, "b": b, "vec": vec}
t_orig = Timer("f(a, b, vec)", globals=g)
print(t_orig.timeit(10_000))
g["f"] = graph_opt
t_opt = Timer("f(a, b, vec)", globals=g)
print(t_opt.timeit(10_000))
```
gives ~2x improvement:
```
f(a, b, vec)
  276.58 us
  1 measurement, 10000 runs , 1 thread
f(a, b, vec)
  118.84 us
  1 measurement, 10000 runs , 1 thread
```
Depending on your function and dimensions you may see even larger improvements.
## License
`opt_einsum_fx` is distributed under an [MIT license](LICENSE).

==> python-opt-einsum-fx-0.1.4/docs/Makefile <==
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
==> python-opt-einsum-fx-0.1.4/docs/api.rst <==
Full Reference
==============
.. automodule:: opt_einsum_fx
    :members:
    :imported-members:

==> python-opt-einsum-fx-0.1.4/docs/conf.py <==
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
# -- Project information -----------------------------------------------------
project = "opt_einsum_fx"
copyright = "2021, Linux-cpp-lisp"
author = "Linux-cpp-lisp"
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx_rtd_theme"]
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
==> python-opt-einsum-fx-0.1.4/docs/index.rst <==
.. opt_einsum_fx documentation master file, created by
   sphinx-quickstart on Wed May 26 15:06:08 2021.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

Welcome to opt_einsum_fx's documentation!
=========================================

.. toctree::
   :maxdepth: 2
   :caption: Contents:

   tutorial.rst
   api.rst
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
python-opt-einsum-fx-0.1.4/docs/make.bat 0000664 0000000 0000000 00000001433 15046477414 0020120 0 ustar 00root root 0000000 0000000 @ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
==> python-opt-einsum-fx-0.1.4/docs/tutorial.rst <==
Tutorial
========
``opt_einsum_fx`` is a library for optimizing einsums and functions involving them using `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_ and `PyTorch FX <https://pytorch.org/docs/stable/fx.html>`_ compute graphs.
The library currently supports:
- Fusing multiple einsums into one
- Optimizing einsums using the `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_ library
- Fusing multiplication and division with scalar constants, including fusing *through* operations, like einsum, that commute with scalar multiplication.
- Placing multiplication by fused scalar constants onto the smallest intermediate in a chain of operations that commute with scalar multiplication.
``opt_einsum_fx`` is based on `torch.fx <https://pytorch.org/docs/stable/fx.html>`_, a framework for converting between PyTorch Python code and a programmatically manipulable compute graph. To use this package, it must be possible to get your function or model as a ``torch.fx.Graph``: the limitations of FX's symbolic tracing are discussed `in the PyTorch documentation <https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing>`_.
Minimal example
---------------
.. code-block:: python

    import torch
    import torch.fx
    import opt_einsum_fx

    def einmatvecmul(a, b, vec):
        """Batched matrix-matrix-vector product using einsum"""
        return torch.einsum("zij,zjk,zk->zi", a, b, vec)

    graph_mod = torch.fx.symbolic_trace(einmatvecmul)
    print("# Original code:\n", graph_mod.code)
    graph_opt = opt_einsum_fx.optimize_einsums_full(
        model=graph_mod,
        example_inputs=(
            torch.randn(7, 4, 5),
            torch.randn(7, 5, 3),
            torch.randn(7, 3)
        )
    )
    print("# Optimized code:\n", graph_opt.code)
outputs:
.. code-block:: python

    # Original code:
    import torch
    def forward(self, a, b, vec):
        einsum_1 = torch.functional.einsum('zij,zjk,zk->zi', a, b, vec); a = b = vec = None
        return einsum_1

    # Optimized code:
    import torch
    def forward(self, a, b, vec):
        einsum_1 = torch.functional.einsum('cb,cab->ca', vec, b); vec = b = None
        einsum_2 = torch.functional.einsum('cb,cab->ca', einsum_1, a); einsum_1 = a = None
        return einsum_2
The ``optimize_einsums_full`` function has four passes (a sketch of running the first three by hand follows the list):
1. Scalar accumulation --- use the multilinearity of einsum to fuse all constant coefficients and divisors of operands and outputs
2. Fusing einsums --- gives greater flexibility to (3)
3. Optimized contraction with ``opt_einsum``
4. Moving constant scalar coefficients through operations they commute with in order to place them on the smallest possible intermediate results
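
For illustration, here is a rough sketch (assembled from this package's public API, not part of the original tutorial) of running the first three passes by hand on the ``einmatvecmul`` example from above; ``EfficientShapeProp`` supplies the shape information that pass (3) needs. ``optimize_einsums_full`` does all of this, plus the final scalar placement, for you:

.. code-block:: python

    import torch
    import torch.fx
    from opt_einsum_fx import (
        fuse_scalars, fuse_einsums, optimize_einsums, EfficientShapeProp
    )

    graph_mod = torch.fx.symbolic_trace(einmatvecmul)
    graph = fuse_scalars(graph_mod.graph)       # pass (1): scalar accumulation
    graph = fuse_einsums(graph, in_place=True)  # pass (2): einsum fusion
    graph_mod = torch.fx.GraphModule(graph_mod, graph)
    # Shape propagation, so that pass (3) knows the operand shapes:
    EfficientShapeProp(graph_mod).run(
        torch.randn(7, 4, 5), torch.randn(7, 5, 3), torch.randn(7, 3)
    )
    graph_mod.graph = optimize_einsums(graph_mod.graph)  # pass (3): opt_einsum
    graph_mod.recompile()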
We can measure the performance improvement (this is on a CPU):
.. code-block:: python

    from torch.utils.benchmark import Timer

    batch = 1000
    a, b, vec = torch.randn(batch, 4, 5), torch.randn(batch, 5, 8), torch.randn(batch, 8)

    g = {"f": graph_mod, "a": a, "b": b, "vec": vec}
    t_orig = Timer("f(a, b, vec)", globals=g)
    print(t_orig.timeit(10_000))

    g["f"] = graph_opt
    t_opt = Timer("f(a, b, vec)", globals=g)
    print(t_opt.timeit(10_000))
gives ~2x improvement:
.. code-block:: none

    f(a, b, vec)
      276.58 us
      1 measurement, 10000 runs , 1 thread
    f(a, b, vec)
      118.84 us
      1 measurement, 10000 runs , 1 thread
Depending on your function and dimensions you may see even larger improvements.
JIT
---
Currently, pure Python and TorchScript have different call signatures for ``torch.tensordot`` and ``torch.permute``, both of which can appear in optimized einsums:
.. code-block:: python

    graph_script = torch.jit.script(graph_opt)  # => RuntimeError: Arguments for call are not valid...
A function is provided to convert ``torch.fx.GraphModule`` s containing these operations from their Python signatures — the default — to a TorchScript compatible form:
.. code-block:: python

    graph_script = torch.jit.script(opt_einsum_fx.jitable(graph_opt))

==> python-opt-einsum-fx-0.1.4/opt_einsum_fx/__init__.py <==
__version__ = "0.1.4"
from ._script import jitable
from ._opt_ein import optimize_einsums, optimize_einsums_full
from ._fuse import fuse_einsums, fuse_scalars
from ._efficient_shape_prop import EfficientShapeProp
__all__ = [
    "jitable",
    "optimize_einsums",
    "optimize_einsums_full",
    "fuse_einsums",
    "fuse_scalars",
    "EfficientShapeProp",
]
==> python-opt-einsum-fx-0.1.4/opt_einsum_fx/_efficient_shape_prop.py <==
from typing import Any, NamedTuple
import opt_einsum
import torch
from torch.fx.node import Node
from ._fuse import _EINSUM_FUNCS
class SimpleMeta(NamedTuple):
    """
    The full ShapeProp defines and uses a NamedTuple to
    store a whole bunch of metadata about the tensors
    going into and out of the Node op. But we don't
    have most of that info, and anyway, I don't think
    most of it's used in opt_einsum or opt_einsum_fx.
    (These are only concerned with computing a summation
    order.)

    Rather than give dummy or default values, which I
    only *assume* would be fine, I'm defining a NamedTuple
    with only the values we actually know. So if I'm wrong
    we will get a very clear error message, rather than
    some invisible error.
    """

    shape: torch.Size
    dtype: torch.dtype
class EfficientShapeProp(torch.fx.Interpreter):
    """
    Like ShapeProp, traverses a graph Node-by-Node
    and records the shape and type of the result
    into each Node.

    Except we treat 'einsum' as a special case.
    We don't actually execute 'einsum' on tensors,
    since the einsums will typically not be optimized
    yet (ShapeProp is called before optimization),
    and inefficient summation order can create
    enormous intermediate tensors, which often creates
    needless out-of-memory errors.

    So we override 'run_node' only for 'einsums'.
    It's straightforward to determine the shape of the
    result just from the output indices.
    (The call to opt_einsum that will typically follow
    this, also doesn't actually build the tensors
    during its exploration.)
    """

    def run_node(self, n: Node) -> Any:
        if n.op == "call_function" and n.target in _EINSUM_FUNCS:
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            equation, *operands = args
            shapes = [op.shape for op in operands]
            assert len({op.dtype for op in operands}) == 1
            meta = SimpleMeta(einsum_shape(equation, *shapes), operands[0].dtype)
            result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=operands[0].device).expand(meta.shape)
        elif n.op == "call_function" and n.target == torch.tensordot:
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            shape_a = [dim for i, dim in enumerate(args[0].shape) if i not in kwargs['dims'][0]]
            shape_b = [dim for i, dim in enumerate(args[1].shape) if i not in kwargs['dims'][1]]
            assert len({op.dtype for op in args}) == 1
            meta = SimpleMeta(shape_a + shape_b, args[0].dtype)
            result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=args[0].device).expand(meta.shape)
        else:
            result = super().run_node(n)
            if isinstance(result, torch.Tensor):
                meta = SimpleMeta(result.shape, result.dtype)
            else:
                meta = None

        n.meta = dict()
        n.meta['tensor_meta'] = meta
        n.meta['type'] = type(result)
        return result

    def propagate(self, *args):
        return super().run(*args)
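
# Note (added) on the `zeros((1,) * ndim).expand(shape)` trick in `run_node` above:
# `expand` returns a zero-stride view, so the placeholder result has the right
# shape, dtype, and device for downstream nodes without ever allocating the
# (possibly enormous) intermediate tensor.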
def einsum_shape(subscripts, *shapes):
    """
    Given an einsum equation and input shapes, returns the output
    shape of the einsum.

    Args:
        subscripts: the einsum formula
        shapes: the input shapes
    """
    Shaped = NamedTuple('Shaped', [('shape', tuple)])
    input_subscripts, output_subscript, _ = opt_einsum.parser.parse_einsum_input(
        (subscripts,) + tuple(Shaped(shape) for shape in shapes)
    )
    dims = {
        i: dim
        for ii, shape in zip(input_subscripts.split(','), shapes)
        for i, dim in zip(ii, shape)
    }
    return tuple(dims[i] for i in output_subscript)
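
# Illustrative check (added; not in the original source): for a batched matmul,
#     einsum_shape("zij,zjk->zik", (7, 4, 5), (7, 5, 3))  ->  (7, 4, 3)
# since the `dims` mapping above resolves z -> 7, i -> 4, k -> 3.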
==> python-opt-einsum-fx-0.1.4/opt_einsum_fx/_fuse.py <==
from typing import List, Optional, Tuple
import string
import copy
import operator
import numbers
import torch
from torch import fx
from opt_einsum.parser import find_output_str
from .fx_utils import get_shape
_EINSUM_FUNCS = {torch.functional.einsum, torch.einsum}
# == Einsum fusion ==
def _get_einstrs(einstr: str) -> Tuple[List[str], str]:
    if "..." in einstr:
        raise NotImplementedError("Ellipsis `...` in einsum string not supported yet")
    tmp = einstr.split("->")
    if len(tmp) == 1:
        ops = tmp[0]
        out = find_output_str(ops)
    elif len(tmp) == 2:
        ops, out = tmp
    else:
        raise ValueError(f"Invalid einstr {einstr}")
    return ops.split(","), out
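
# For example (illustrative): _get_einstrs("ij,jk->ik") returns (["ij", "jk"], "ik"),
# and for an implicit-output equation, _get_einstrs("ij,jk") infers the output
# "ik" via opt_einsum's find_output_str.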
def fuse_einsums(graph: fx.Graph, in_place: bool = False) -> fx.Graph:
    """Fuse einsums when possible.

    When the output of one einsum is only used as an operand in another einsum, the two einsums can be fused into one.

    Example:

    .. code-block:: python

        def fusable(x, y):
            z = torch.einsum("ij,jk->ik", x, y)
            return torch.einsum("ik,ij->i", z, x)

        g = torch.fx.symbolic_trace(fusable)
        print(fuse_einsums(g.graph).python_code(""))

    gives::

        import torch
        def forward(self, x, y):
            einsum_2 = torch.functional.einsum('ib,bk,ij->i', x, y, x); x = y = None
            return einsum_2

    Args:
        graph: the graph to process.
        in_place (bool, optional): whether to process ``graph`` in place.

    Returns:
        The graph with fused einsums.
    """
    if not in_place:
        graph = copy.deepcopy(graph)

    for node in graph.nodes:
        if node.op == "call_function" and node.target in _EINSUM_FUNCS:
            our_inp_einstrs, our_out_einstr = _get_einstrs(node.args[0])
            assert len(our_inp_einstrs) == len(node.args) - 1
            avail_letters = iter(
                set(string.ascii_lowercase)
                - set.union(*(set(e) for e in our_inp_einstrs))
            )
            new_our_einstrs = []
            new_our_args = []
            we_fused_nodes = []
            # Iterate over operands
            for inp_idex, inp in enumerate(node.args[1:]):
                if (
                    inp.op == "call_function"
                    and inp.target in _EINSUM_FUNCS
                    and len(inp.users) == 1
                ):
                    # This operand is the output of another einsum, and is not used by any other operation
                    # As a result, we can fuse it
                    its_inp_einstrs, its_out_einstr = _get_einstrs(inp.args[0])
                    if len(its_out_einstr) != len(our_inp_einstrs[inp_idex]):
                        raise RuntimeError(
                            f"Inconsistent rank: einsum `{node}`'s input {inp_idex} is the result of einsum {inp}; the output of `{inp}` is labeled `{its_out_einstr}` (rank {len(its_out_einstr)}), but the corresponding input of `{node}` is labeled `{our_inp_einstrs[inp_idex]}` (rank {len(our_inp_einstrs[inp_idex])})"
                        )
                    # First, we need to figure out which of its output dimensions correspond to our dimensions:
                    its_dim_to_ours = dict(
                        zip(its_out_einstr, our_inp_einstrs[inp_idex])
                    )
                    # assign any labels that don't show up in the output of the previous einsum --- and thus don't have labels in the current einsum --- to new letters
                    its_remaining_labels = set.union(
                        *(set(e) for e in its_inp_einstrs)
                    ) - set(its_dim_to_ours.keys())
                    try:
                        its_dim_to_ours.update(
                            dict((i, next(avail_letters)) for i in its_remaining_labels)
                        )
                    except StopIteration:
                        # We ran out of letters
                        raise NotImplementedError(
                            f"At einsum {node}, ran out of letters when trying to fuse parameter einsum {inp}. A fallback for this case is not yet implemented."
                        )
                    else:
                        # We had enough letters, finish adding the fuse
                        del its_remaining_labels
                        new_our_args.extend(inp.args[1:])
                        new_our_einstrs.extend(
                            "".join(its_dim_to_ours[d] for d in es)
                            for es in its_inp_einstrs
                        )
                        we_fused_nodes.append(inp)
                else:
                    # This argument is not from an einsum, or is from an einsum that is used elsewhere as well
                    # Thus we just pass it through
                    new_our_einstrs.append(our_inp_einstrs[inp_idex])
                    new_our_args.append(inp)
            # -- end iter over prev einsum inputs --
            # Set the new values for the einstrs
            node.args = (f"{','.join(new_our_einstrs)}->{our_out_einstr}",) + tuple(
                new_our_args
            )
            # Remove fused inputs
            for to_remove in we_fused_nodes:
                graph.erase_node(to_remove)
        # -- end case for einsum nodes --
    # -- end iter over nodes --
    return graph
# == Scalar fusion ==
#
# Note that in general we do not support scalar fusion through in-place operations; it complicates following things through the compute graph too much
# TODO: ^ ???
# TODO: should the accumulation of constants happen in more than double precision?
def _get_node_and_scalar(node: fx.Node) -> Tuple[fx.Node, Optional[numbers.Number]]:
    """Get a multiplicative scalar for an operation, if applicable."""
    # This supports in-place *= and /= because fx traces them as normal operator.mul/div.
    if node.op == "call_function":
        if node.target == operator.mul or node.target == torch.mul:
            if isinstance(node.args[0], numbers.Number):
                return node.args[1], node.args[0]
            elif isinstance(node.args[1], numbers.Number):
                return node.args[0], node.args[1]
        elif node.target == operator.truediv or node.target == torch.div:
            if isinstance(node.args[1], numbers.Number):
                return node.args[0], 1.0 / node.args[1]
    elif node.op == "call_method":
        # TODO: this could _technically_ be wrong if the node's `self` argument is not a (proxy to) a Tensor
        if node.target == "mul":
            if isinstance(node.args[1], numbers.Number):
                return node.args[0], node.args[1]
        elif node.target == "div":
            if isinstance(node.args[1], numbers.Number):
                return node.args[0], 1.0 / node.args[1]
    return node, None
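
# Examples (illustrative): for a node tracing `3.0 * x` this returns (x, 3.0);
# for `x / 4.0` it returns (x, 0.25); for a non-scalar op it returns (node, None).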
# Operations that are (almost) "multilinear", in the sense that they commute with scalar multiplication of their operands
SCALAR_COMMUTE_OPS = [
    torch.einsum,
    torch.functional.einsum,
    torch.tensordot,
    torch.functional.tensordot,
    "permute",
    # "reshape",
    "mul",
    "div",
    operator.mul,
    operator.truediv,
]
def prod(x):
    """Compute the product of a sequence."""
    out = 1
    for a in x:
        out *= a
    return out
def fuse_scalars(graph: fx.Graph, in_place: bool = False) -> fx.Graph:
    """Use the multilinearity of einsum to unify and remove constant scalars around einsums.

    Args:
        graph: the graph to process.
        in_place (bool, optional): whether to process ``graph`` in place.

    Returns:
        The graph with fused scalars.
    """
    if not in_place:
        graph = copy.deepcopy(graph)

    # Clear any previous state this graph has
    for node in graph.nodes:
        if hasattr(node, "in_lin_chain"):
            delattr(node, "in_lin_chain")

    # Find chains of multilinear ops
    seen_nodes = set()
    linear_chains = []
    for node in graph.nodes:
        if id(node) in seen_nodes:
            continue
        # Determine a linear chain
        cur_linear_chain = []
        while (
            id(node) not in seen_nodes
            and getattr(node, "target", None) in SCALAR_COMMUTE_OPS
        ):
            seen_nodes.add(id(node))
            node.in_lin_chain = len(linear_chains)
            cur_linear_chain.append(node)
            # Continue building the chain regardless, since the merger uses this
            users = list(node.users.keys())
            if len(users) > 0:
                # Get the next node in the chain
                node = users[0]
            else:
                # This isn't used in the graph at all, break the chain
                node = None
            if len(users) != 1:
                # End this chain
                break

        # If the next user, which is now in node, was seen but is itself in a linear chain, this means we merge them
        # TODO: thoroughly test this
        if hasattr(node, "in_lin_chain") and len(cur_linear_chain) > 0:
            # Merge
            merge_into = node.in_lin_chain
            for n in cur_linear_chain:
                n.in_lin_chain = merge_into
            linear_chains[merge_into].extend(cur_linear_chain)
        else:
            # This is a new chain
            linear_chains.append(cur_linear_chain)

    # Accumulate scalars in them
    scalars = []
    for lin_chain_i, lin_chain in enumerate(linear_chains):
        if len(lin_chain) < 2:
            # There's nothing to do here: either the chain is empty,
            # or there's only one operation — even if it's a scalar multiplication,
            # there's nothing for us to do with it
            scalars.append(None)
            continue

        # Accumulate scalars
        scalar_node_idexes = []
        total_scalar = 1.0
        for node_i, node in enumerate(lin_chain):
            new_node, scalar = _get_node_and_scalar(node)
            if scalar is not None:
                total_scalar *= scalar
                scalar_node_idexes.append(node_i)

        is_all_scalars = len(scalar_node_idexes) == len(lin_chain)

        # Remove scalar nodes
        for node_i in scalar_node_idexes:
            node = lin_chain[node_i]
            new_node, scalar = _get_node_and_scalar(node)
            assert scalar is not None
            if is_all_scalars and node_i == len(lin_chain) - 1:
                # If it's all scalars, we just put the total_scalar into the last operation
                # and don't save a scalar for later
                with graph.inserting_after(node):
                    new_node = graph.call_function(
                        operator.mul,
                        (total_scalar, new_node),
                    )
                total_scalar = None
            node.replace_all_uses_with(new_node)
            graph.erase_node(node)

        # Save the scalar for this chain
        scalars.append(total_scalar)

        # Remove all of the removed scalar operations from the lin chain
        # See https://stackoverflow.com/a/11303234/1008938
        for index in sorted(
            (scalar_node_idexes[:-1] if is_all_scalars else scalar_node_idexes),
            reverse=True,
        ):
            del lin_chain[index]

    del seen_nodes

    # Make sure everything is still OK
    graph.lint()

    # Now we have chains without scalar operations; we can go through and add back in the scalars in the optimal place
    for lin_chain_i, lin_chain in enumerate(linear_chains):
        if (
            len(lin_chain) == 0
            or scalars[lin_chain_i] == 1.0
            or scalars[lin_chain_i] is None
        ):
            # Nothing to do with an empty chain
            # No reason to add back a scalar that does nothing
            # None signals don't process from above
            continue

        # Find the smallest argument or the output
        smallest_node_i = None
        smallest_arg_i = None
        smallest_size = float("inf")
        for node_i, node in enumerate(lin_chain):
            for arg_i, arg in enumerate(node.args):
                if not isinstance(arg, fx.Node):
                    continue
                shape = get_shape(arg)
                if shape is not None and prod(shape) < smallest_size:
                    smallest_node_i = node_i
                    smallest_arg_i = arg_i
                    smallest_size = prod(shape)

        # Put the accumulated scalar on a node
        if (smallest_node_i is None) or (
            get_shape(lin_chain[-1]) is not None
            and prod(get_shape(lin_chain[-1])) < smallest_size
        ):
            # The output is the smallest, put it there
            # OR there was no smallest argument, put it on the end of the chain
            with graph.inserting_after(lin_chain[-1]):
                new_node = graph.call_function(operator.mul, tuple())  # placeholder
            lin_chain[-1].replace_all_uses_with(new_node)
            new_node.args = (lin_chain[-1], scalars[lin_chain_i])
        else:
            # The smallest was someone's arg, so we replace that with a scalar multiplication:
            with graph.inserting_before(lin_chain[smallest_node_i]):
                new_arg = graph.call_function(
                    operator.mul,
                    (
                        lin_chain[smallest_node_i].args[smallest_arg_i],
                        scalars[lin_chain_i],
                    ),
                )
            new_args = list(lin_chain[smallest_node_i].args)
            new_args[smallest_arg_i] = new_arg
            lin_chain[smallest_node_i].args = tuple(new_args)

    graph.lint()

    return graph
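
# Illustrative effect (added; cf. tests/test_fuse.py): tracing
#     def f(x, y):
#         return 7.0 * torch.einsum("ij,jk->ik", x, y / 3) / 2
# and running fuse_scalars on its graph leaves a single einsum followed by one
# multiplication by the accumulated constant 7.0 * (1 / 3) * (1 / 2).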
==> python-opt-einsum-fx-0.1.4/opt_einsum_fx/_opt_ein.py <==
from typing import Callable, Union
import warnings
import torch
from torch import fx
from ._efficient_shape_prop import EfficientShapeProp as ShapeProp
import opt_einsum
from opt_einsum.contract import _core_contract
from ._fuse import fuse_einsums, fuse_scalars, _EINSUM_FUNCS
from .fx_utils import get_shape
def optimize_einsums_full(
    model: Union[torch.nn.Module, Callable, fx.Graph],
    example_inputs: tuple,
    contract_kwargs: dict = {},
    tracer_class: type = fx.Tracer,
) -> Union[fx.GraphModule, fx.Graph]:
    """Optimize einsums in ``model`` for ``example_inputs``.

    All of the restrictions of ``torch.fx`` symbolic tracing apply.

    Applies, in order, four optimizations:

    1. Scalar accumulation --- use the multilinearity of einsum to collect all constant coefficients and divisors of operands and outputs
    2. Fusing einsums --- gives greater flexibility to (3)
    3. Optimized contraction with ``opt_einsum``.
    4. Moving constant scalar coefficients through operations they commute with in order to place them on the smallest possible intermediate results

    Args:
        model (torch.nn.Module or callable or fx.Graph): the model, function, or ``fx.Graph`` to optimize.
        example_inputs (tuple): arguments to ``model`` whose shapes will determine the einsum optimizations.
        tracer_class (type, optional): the tracer class to use to turn ``model`` into an ``fx.Graph`` if it isn't already an ``fx.GraphModule`` or ``fx.Graph``.

    Returns:
        An optimized ``fx.GraphModule``, or if ``model`` is an ``fx.Graph``, an optimized ``fx.Graph``.
    """
    output_graph = False
    if isinstance(model, fx.GraphModule):
        graph: fx.Graph = model.graph
    elif isinstance(model, fx.Graph):
        graph: fx.Graph = model
        model = torch.nn.Module()
        output_graph = True
    else:
        tracer: fx.Tracer = tracer_class()
        graph: fx.Graph = tracer.trace(model)
        model = tracer.root

    # 1. Scalar accumulation
    # without shape information, this just accumulates scalars and moves them to the end of chains of linear operations
    graph = fuse_scalars(graph)

    # 2. Fuse any einsums we can
    # This gives opt_einsum the most freedom possible to rearrange things
    # Since we already moved scalars to the end of chains of linear operations, any scalars between linear operations should already have been moved
    graph = fuse_einsums(graph, in_place=True)
    out_mod = fx.GraphModule(model, graph)

    # 3. Shape propagation
    sp = ShapeProp(out_mod)
    sp.run(*example_inputs)

    # 4. Optimize einsums
    out_mod.graph = optimize_einsums(out_mod.graph, contract_kwargs)
    out_mod.recompile()

    # 5. Shape prop (again)
    # We need shapes to put the scalars in the best place
    sp = ShapeProp(out_mod)
    sp.run(*example_inputs)

    # 6. Final scalar fusion to move scalars
    out_mod.graph = fuse_scalars(out_mod.graph, in_place=True)

    if output_graph:
        return out_mod.graph
    else:
        out_mod.recompile()
        return out_mod
# Based on "Proxy Retracing" example in https://pytorch.org/docs/stable/fx.html
def optimize_einsums(graph: fx.Graph, contract_kwargs: dict = {}) -> fx.Graph:
    """Optimize einsums in a ``torch.fx.Graph`` using ``opt_einsum``.

    ``graph`` must have shape information such as that populated by ``torch.fx.passes.shape_prop.ShapeProp``. The shapes are used for ``opt_einsum`` and the result is specific to the number of dimensions in the provided shapes. From ``opt_einsum``:

        ...while it will work for a set of arrays with the same ranks as the original shapes but differing sizes, it might no longer be optimal.

    See the ``opt_einsum`` `documentation <https://optimized-einsum.readthedocs.io/en/stable/>`_ for more details.

    Args:
        graph (fx.Graph): the graph to optimize
        contract_kwargs: extra keyword arguments for ``opt_einsum.contract_path``.

    Returns:
        An optimized ``fx.Graph``.
    """
    defaults = {
        "optimize": "optimal",
    }
    defaults.update(contract_kwargs)
    contract_kwargs = defaults

    new_graph = fx.Graph()
    tracer = fx.proxy.GraphAppendingTracer(new_graph)
    # env keeps track of new injected nodes in addition to existing ones,
    # making sure they get into new_graph
    env = {}
    node_processed: bool = False
    for node in graph.nodes:
        node_processed = False
        if node.op == "call_function" and node.target in _EINSUM_FUNCS:
            # Get shapes
            shapes = [get_shape(a) for a in node.args[1:]]
            if any(s is None for s in shapes):
                warnings.warn(
                    f"einsum {repr(node)} lacked shape information; "
                    "not optimizing. "
                    "Did you forget to run ShapeProp on this graph?",
                    RuntimeWarning,
                )
            else:
                # We have shapes, so:
                # Determine the optimal contraction
                path, path_info = opt_einsum.contract_path(
                    node.args[0],  # the einstr
                    *shapes,
                    shapes=True,
                    **contract_kwargs,
                )
                # By wrapping the arguments with proxies,
                # we can dispatch to opt_einsum and implicitly
                # add it to the Graph by symbolically tracing it.
                proxy_args = [
                    fx.Proxy(env[x.name], tracer=tracer) if isinstance(x, fx.Node) else x
                    for x in node.args
                ]
                # Use _core_contract to avoid `len()` calls that
                # fx can't deal with
                output_proxy = _core_contract(
                    proxy_args[1:],
                    path_info.contraction_list,
                    backend="torch",
                    evaluate_constants=False,
                )
                # Operations on `Proxy` always yield new `Proxy`s, and the
                # return value of our decomposition rule is no exception.
                # We need to extract the underlying `Node` from the `Proxy`
                # to use it in subsequent iterations of this transform.
                new_node = output_proxy.node
                env[node.name] = new_node
                node_processed = True

        if not node_processed:
            # Default case: just copy the node over into the new graph.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node

    new_graph.lint()
    return new_graph
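
# Typical usage (a sketch, added here; mirrors tests/test_einsum_optimizer.py):
#
#     gm = torch.fx.symbolic_trace(f)
#     EfficientShapeProp(gm).run(*example_inputs)
#     gm.graph = optimize_einsums(gm.graph)
#     gm.recompile()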
==> python-opt-einsum-fx-0.1.4/opt_einsum_fx/_script.py <==
from typing import Union
from packaging import version
import torch
from torch import fx
# see https://github.com/pytorch/pytorch/issues/53487
def jitable(obj: Union[fx.GraphModule, fx.Graph]) -> Union[fx.GraphModule, fx.Graph]:
    """Convert some torch calls into their TorchScript signatures.

    In place. Currently deals with ``tensordot`` and ``permute``.

    Args:
        obj: the ``fx.Graph`` or ``fx.GraphModule`` to process.

    Returns:
        ``obj``, modified in-place.
    """
    if isinstance(obj, fx.GraphModule):
        graph = obj.graph
    else:
        graph = obj
    torch_is_ge_19: bool = version.parse(torch.__version__) >= version.parse("1.9.0")
    for node in graph.nodes:
        if node.op == "call_function":
            if (
                node.target == torch.tensordot
                or node.target == torch.functional.tensordot
            ):
                if "dims" in node.kwargs:
                    args = list(node.args)
                    kwargs = dict(node.kwargs)
                    dim_self, dim_other = kwargs.pop("dims")
                    assert len(args) == 2  # tensors 1 and 2
                    if torch_is_ge_19:
                        # In torch >= 1.9.0, they've corrected the torchscript interface
                        # to align with the python one:
                        args.append((list(dim_self), list(dim_other)))
                    else:
                        args.append(list(dim_self))
                        args.append(list(dim_other))
                    node.args = tuple(args)
                    node.kwargs = kwargs
        elif node.op == "call_method":
            if node.target == "permute":
                self_arg, args = node.args[0], node.args[1:]
                if not isinstance(args[0], list):
                    node.args = [self_arg, list(args)]
    graph.lint()
    if isinstance(obj, fx.GraphModule):
        obj.recompile()
    return obj
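
# Intended use (a sketch, added here; mirrors the docs): make the optimized
# module TorchScript-compatible before scripting it:
#
#     graph_opt = opt_einsum_fx.jitable(graph_opt)
#     graph_script = torch.jit.script(graph_opt)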
==> python-opt-einsum-fx-0.1.4/opt_einsum_fx/fx_utils.py <==
from typing import Optional
import torch
from torch import fx
def get_shape(n: fx.Node) -> Optional[torch.Size]:
    """Get the shape of a node after ``ShapeProp``"""
    try:
        return n.meta["tensor_meta"].shape
    except KeyError:
        return None
    except AttributeError:
        return None
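
# Note (added): this returns None both when shape metadata was never populated
# (KeyError: ShapeProp was not run on the graph) and when `tensor_meta` is None
# for a non-Tensor result (AttributeError on `.shape`).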
==> python-opt-einsum-fx-0.1.4/pyproject.toml <==
[build-system]
requires = [
    "setuptools>=42",
    "wheel"
]
build-backend = "setuptools.build_meta"

==> python-opt-einsum-fx-0.1.4/setup.cfg <==
[flake8]
max-line-length = 127
select = E,F,W,C
ignore = E226,E501,E741,E743,C901,W503
exclude = .eggs,*.egg,build,dist,docs
python-opt-einsum-fx-0.1.4/setup.py 0000664 0000000 0000000 00000001676 15046477414 0017306 0 ustar 00root root 0000000 0000000 import setuptools
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
setuptools.setup(
name="opt_einsum_fx",
version="0.1.4", # remember to update in `__init__.py` too!
author="Linux-cpp-lisp",
url="https://github.com/Linux-cpp-lisp/opt_einsum_fx",
description="Einsum optimization using opt_einsum and PyTorch FX",
long_description=long_description,
long_description_content_type="text/markdown",
license="MIT",
license_files="LICENSE",
project_urls={
"Bug Tracker": "https://github.com/Linux-cpp-lisp/opt_einsum_fx/issues",
},
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Development Status :: 4 - Beta",
],
python_requires=">=3.6",
install_requires=["torch>=1.8.0", "opt_einsum", "packaging"],
packages=["opt_einsum_fx"],
)
==> python-opt-einsum-fx-0.1.4/tests/conftest.py <==
import pytest
import torch
FLOAT_TOLERANCE = {
    t: torch.as_tensor(v, dtype=t)
    for t, v in {torch.float32: 1e-5, torch.float64: 1e-10}.items()
}


@pytest.fixture(scope="session", autouse=True, params=["float32", "float64"])
def float_tolerance(request):
    """Run all tests with various PyTorch default dtypes.

    This is a session-wide, autouse fixture — you only need to request it explicitly if a test needs to know the tolerance for the current default dtype.

    Returns
    --------
    A precision threshold to use for closeness tests.
    """
    old_dtype = torch.get_default_dtype()
    dtype = {"float32": torch.float32, "float64": torch.float64}[request.param]
    torch.set_default_dtype(dtype)
    yield FLOAT_TOLERANCE[dtype]
    torch.set_default_dtype(old_dtype)


@pytest.fixture(scope="session")
def allclose(float_tolerance):
    return lambda x, y: torch.allclose(x, y, atol=float_tolerance)
==> python-opt-einsum-fx-0.1.4/tests/test_einsum_optimizer.py <==
import pytest
import torch
import torch.fx
from opt_einsum_fx import optimize_einsums, optimize_einsums_full, jitable, EfficientShapeProp
def einmatmul(x, y):
    return torch.einsum("ij,jk->ik", x, y)


def eintrace(x, y):
    # these indexings make it square
    b = torch.einsum("ii", x[:, : x.shape[0]])
    return torch.einsum("jj", y[:, : y.shape[0]]) * b


def fusable(x, y):
    z = torch.einsum("ij,jk->ik", x, y)
    return torch.einsum("ik,ij->i", z, x)


def fusable_w_scalars(x, y):
    z = torch.einsum("ij,jk->ik", x, y) / 3.0
    return 4.0 * torch.einsum("ik,ij->i", z, x)


def unfusable(x, y):
    z = torch.einsum("ij,jk->ik", x, y)
    # We use z as something besides an input to the second einsum, so it is unfusable
    return torch.einsum("ik,ij->i", z, x) + z[:, 0]


def unfusable_w_scalars(x, y):
    z = 2.7 * torch.einsum("ij,jk->ik", x, y)
    # We use z as something besides an input to the second einsum, so it is unfusable
    return torch.einsum("ik,ij->i", z, x) + 1.1 * z[:, 0]


def not_einsum(x, y):
    # Try to trip it up with lots of scalar fusion but no einsums
    return 3.0 * 2.7 * x.sum() + (4.6 / y.relu().sum())


def not_einsum2(x, y):
    a = x.tanh().relu().sum() - y.sum()
    b = 3.41 * y.sum().tanh()
    return a - 6.7 * b


@pytest.fixture(
    scope="module",
    params=[
        einmatmul,
        eintrace,
        fusable,
        fusable_w_scalars,
        unfusable,
        unfusable_w_scalars,
        not_einsum,
        not_einsum2,
    ],
)
def einfunc(request):
    return request.param


def test_optimize_einsums(einfunc, allclose):
    x = torch.randn(3, 4)
    y = torch.randn(4, 5)
    func_res = einfunc(x, y)
    func_fx = torch.fx.symbolic_trace(einfunc)
    sp = EfficientShapeProp(func_fx)
    sp.run(x, y)
    func_fx_res = func_fx(x, y)
    assert torch.all(func_res == func_fx_res)
    graph_opt = optimize_einsums(func_fx.graph)
    func_fx.graph = graph_opt
    func_fx.recompile()
    func_opt_res = func_fx(x, y)
    assert allclose(func_opt_res, func_fx_res)


def test_optimize_einsums_full(einfunc, allclose):
    x = torch.randn(3, 4)
    y = torch.randn(4, 5)
    func_res = einfunc(x, y)
    func_opt = optimize_einsums_full(einfunc, (x, y))
    assert allclose(func_res, func_opt(x, y))


def test_fallback():
    # We only bother to test this for one function
    einfunc = fusable
    # If there is no shape propagation, it should warn
    # and not do anything.
    func_fx = torch.fx.symbolic_trace(einfunc)
    old_code = func_fx.code
    with pytest.warns(RuntimeWarning):
        graph_opt = optimize_einsums(func_fx.graph)
    func_fx.graph = graph_opt
    func_fx.recompile()
    assert old_code == func_fx.code


def test_torchscript(einfunc, allclose):
    x = torch.randn(3, 4)
    y = torch.randn(4, 5)
    func_res = einfunc(x, y)
    mod_opt = optimize_einsums_full(einfunc, (x, y))
    mod_opt = jitable(mod_opt)
    mod_opt = torch.jit.script(mod_opt)
    func_opt_res = mod_opt(x, y)
    assert allclose(func_opt_res, func_res)
==> python-opt-einsum-fx-0.1.4/tests/test_fuse.py <==
import pytest
import math
import operator
import torch
import torch.fx
from opt_einsum_fx import fuse_einsums, fuse_scalars, optimize_einsums_full
def test_einsum_fuse(allclose):
    def fusable(x, y):
        z = torch.einsum("ij,jk->ik", x, y)
        return torch.einsum("ik,ij->i", z, x)

    g = torch.fx.symbolic_trace(fusable)
    new_graph = fuse_einsums(g.graph)
    g.graph = new_graph
    g.recompile()
    x, y = torch.randn(3, 4), torch.randn(4, 5)
    out_truth = fusable(x, y)
    out_fused = g(x, y)
    assert allclose(out_fused, out_truth)


def test_unfusable():
    def unfusable(x, y):
        z = torch.einsum("ij,jk->ik", x, y)
        # We use z as something besides an input to the second einsum, so it is unfusable
        return torch.einsum("ik,ij->i", z, x) + z[:, 0]

    g = torch.fx.symbolic_trace(unfusable)
    old_code = g.code
    new_graph = fuse_einsums(g.graph)
    g.graph = new_graph
    g.recompile()
    # Confirm numerical equivalence
    x, y = torch.randn(3, 4), torch.randn(4, 5)
    out_truth = unfusable(x, y)
    out_fused = g(x, y)
    # Here we use normal allclose --- since unfusable is unfusable,
    # nothing should have changed.
    assert torch.allclose(out_fused, out_truth)
    # Confirm no fusion:
    assert old_code == g.code


def test_doublefuse(allclose):
    def doublefuse(a, b, c, d):
        # quadruple matmul with a final transpose
        e1 = torch.einsum("ij,jk->ik", a, b)
        e2 = torch.einsum("ab,bc->ac", e1, c)
        return torch.einsum("tr,ry->yt", e2, d)

    g = torch.fx.symbolic_trace(doublefuse)
    new_graph = fuse_einsums(g.graph)
    g.graph = new_graph
    g.recompile()
    a, b, c, d = (
        torch.randn(3, 4),
        torch.randn(4, 5),
        torch.randn(5, 2),
        torch.randn(2, 3),
    )
    out_truth = doublefuse(a, b, c, d)
    out_fused = g(a, b, c, d)
    assert allclose(out_fused, out_truth)


def test_inconsistent():
    def inconsistent(x, y):
        z = torch.einsum("ij,jk->ik", x, y)
        # Note that the dimension labels for z have the wrong length
        return torch.einsum("i,ij->i", z, x)

    g = torch.fx.symbolic_trace(inconsistent)
    with pytest.raises(RuntimeError):
        _ = fuse_einsums(g.graph)


def scalar_fusable1(x, y):
    return 7.0 * torch.einsum("ij,jk->ik", x, y / 3) / 2


def scalar_fusable2(x, y):
    return 4.0 * torch.einsum("ij,jk->ik", x, 2.0 * y / 3) / 2


def scalar_fusable3(x, y):
    return 4.0 * torch.einsum("ij,jk->ik", x / 1.2, 1.7 * 2.0 * y / 3) / 2


def scalar_unfusable(x, y):
    z = 3 * torch.einsum("ij,jk->ik", x, y) / 4.0
    # We use z as something besides an input to the second einsum, so it is unfusable
    return (2.0 * torch.einsum("ik,ij->i", z, x)) + z[:, 0]


def just_scalars(x, y):
    return 3.0 * x


def just_many_scalars(x, y):
    return 3.0 / 3.4 * x / 4.0


def in_place(x, y):
    # This *shouldn't* be fused.
    a = x.clone()
    b = a.mul_(4.0)
    return 3.0 * b


def unused(x, y):
    b = 2.3 * x / 4.5  # noqa
    return 4.6 * torch.einsum("ij,jk->ik", x, y)


def constants(x, y):
    return math.pi * torch.einsum("ij,jk->ik", x, math.e * y / 3) / 2


# In all cases but unfusable, after fusion, the graph should have 5 nodes:
# two placeholders, one einsum, one mul, and one output
@pytest.mark.parametrize(
    "func",
    [
        (scalar_fusable1, 5),
        (scalar_fusable2, 5),
        (scalar_fusable3, 5),
        (
            scalar_unfusable,
            9,  # two placeholders, one einsum one mul, one einsum one mul, one getitem, one sum, and one output = 9
        ),
        (just_scalars, 4),
        (just_many_scalars, 4),
        (in_place, 6),
        (constants, 5),
        (unused, 6),
    ],
)
def test_scalar_fuse(allclose, func):
    func, truth_num_nodes = func
    g = torch.fx.symbolic_trace(func)
    print("old graph\n", g.graph)
    new_graph = fuse_scalars(g.graph)
    print("new graph\n", new_graph)
    g.graph = new_graph
    assert len(g.graph.nodes) == truth_num_nodes
    g.recompile()
    x, y = torch.randn(3, 4), torch.randn(4, 5)
    out_truth = func(x, y)
    out_fused = g(x, y)
    assert allclose(out_fused, out_truth)


def test_scalar_positioning(allclose):
    def f(x, y, z):
        return 0.784 * torch.einsum("ij,jk,kl->il", x, y, z)

    x, y, z = torch.randn(2, 100), torch.randn(100, 2), torch.randn(2, 100)
    # note that the smallest here is y
    g = torch.fx.symbolic_trace(f)
    print("old graph\n", g.graph)
    g = optimize_einsums_full(g, (x, y, z))
    print("new graph\n", g.graph)
    # optimal placement is on the 2x2 intermediate
    assert list(g.graph.nodes)[4].target == operator.mul
    out_truth = f(x, y, z)
    out_fused = g(x, y, z)
    assert allclose(out_fused, out_truth)