pax_global_header00006660000000000000000000000064150464774140014526gustar00rootroot0000000000000052 comment=8fe83f83ed005c0e7342a4c69d1484e2d09fdbc3 python-opt-einsum-fx-0.1.4/000077500000000000000000000000001504647741400155625ustar00rootroot00000000000000python-opt-einsum-fx-0.1.4/.github/000077500000000000000000000000001504647741400171225ustar00rootroot00000000000000python-opt-einsum-fx-0.1.4/.github/workflows/000077500000000000000000000000001504647741400211575ustar00rootroot00000000000000python-opt-einsum-fx-0.1.4/.github/workflows/release.yml000066400000000000000000000012351504647741400233230ustar00rootroot00000000000000name: Upload Python Package on: release: types: [created] jobs: deploy: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: Set up Python uses: actions/setup-python@v2 with: python-version: 3.8 - name: Install dependencies run: | python -m pip install --upgrade pip pip install setuptools wheel twine - name: Build and publish env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | python setup.py sdist bdist_wheel twine upload dist/* python-opt-einsum-fx-0.1.4/.github/workflows/tests.yml000066400000000000000000000021641504647741400230470ustar00rootroot00000000000000name: Check Syntax and Run Tests on: push: branches: - main pull_request: branches: - main jobs: build: runs-on: ubuntu-latest strategy: matrix: python-version: [3.6, 3.9] torch-version: [1.8.0, 1.10.0] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - name: Install flake8 run: | pip install flake8 - name: Lint with flake8 run: | flake8 . --count --show-source --statistics - name: Install dependencies env: TORCH: "${{ matrix.torch-version }}" GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | python -m pip install --upgrade pip pip install wheel pip install torch==${TORCH} torchvision torchaudio -f https://download.pytorch.org/whl/cpu/torch_stable.html pip install . - name: Install pytest run: | pip install pytest - name: Test with pytest run: | pytest --doctest-modules --ignore=docs/ . python-opt-einsum-fx-0.1.4/.gitignore000066400000000000000000000034171504647741400175570ustar00rootroot00000000000000# Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
# However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ .vscodepython-opt-einsum-fx-0.1.4/.readthedocs.yaml000066400000000000000000000006231504647741400210120ustar00rootroot00000000000000# .readthedocs.yaml # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details # Required version: 2 # Build documentation in the docs/ directory with Sphinx sphinx: configuration: docs/conf.py # Optionally set the version of Python and requirements required to build your docs python: version: 3.7 install: - method: pip path: .python-opt-einsum-fx-0.1.4/CHANGELOG.md000066400000000000000000000014501504647741400173730ustar00rootroot00000000000000# Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). Most recent change on the bottom. ## [Unreleased] ## 0.1.4 - 2021-11-7 ### Added - `opt_einsum_fx.__version__` - Partially symbolic shape propagation for efficient einsum optimization (#15) ## 0.1.3 - 2021-10-29 ### Added - PyTorch 1.10 compatability ### Fixed - Added `packaging` to dependency list ## 0.1.2 - 2021-06-28 ### Added - PyTorch 1.9 compatibility ## 0.1.1 - 2021-05-27 ### Added - Docs - PyPI package ### Fixed - `jitable` no longer makes some FX nodes' `.args` lists (technically not allowed) instead of tuples ## [0.1.0] - 2021-05-17python-opt-einsum-fx-0.1.4/LICENSE000066400000000000000000000020501504647741400165640ustar00rootroot00000000000000MIT License Copyright (c) 2021 Alby M. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
python-opt-einsum-fx-0.1.4/README.md000066400000000000000000000051411504647741400170420ustar00rootroot00000000000000# opt_einsum_fx [![Documentation Status](https://readthedocs.org/projects/opt-einsum-fx/badge/?version=latest)](https://opt-einsum-fx.readthedocs.io/en/latest/?badge=latest) Optimizing einsums and functions involving them using [`opt_einsum`](https://optimized-einsum.readthedocs.io/en/stable/) and PyTorch [FX](https://pytorch.org/docs/stable/fx.html) compute graphs. Issues, questions, PRs, and any thoughts about further optimizing these kinds of operations are welcome! For more information please see [the docs](https://opt-einsum-fx.readthedocs.io/en/stable/). ## Installation ### PyPI The latest release can be installed from PyPI: ```bash $ pip install opt_einsum_fx ``` ### Source To get the latest code, run: ```bash $ git clone https://github.com/Linux-cpp-lisp/opt_einsum_fx.git ``` and install it by running ```bash $ cd opt_einsum_fx/ $ pip install . ``` You can run the tests with ```bash $ pytest tests/ ``` ## Minimal example ```python import torch import torch.fx import opt_einsum_fx def einmatvecmul(a, b, vec): """Batched matrix-matrix-vector product using einsum""" return torch.einsum("zij,zjk,zk->zi", a, b, vec) graph_mod = torch.fx.symbolic_trace(einmatvecmul) print("Original code:\n", graph_mod.code) graph_opt = opt_einsum_fx.optimize_einsums_full( model=graph_mod, example_inputs=( torch.randn(7, 4, 5), torch.randn(7, 5, 3), torch.randn(7, 3) ) ) print("Optimized code:\n", graph_opt.code) ``` outputs ``` Original code: import torch def forward(self, a, b, vec): einsum_1 = torch.functional.einsum('zij,zjk,zk->zi', a, b, vec); a = b = vec = None return einsum_1 Optimized code: import torch def forward(self, a, b, vec): einsum_1 = torch.functional.einsum('cb,cab->ca', vec, b); vec = b = None einsum_2 = torch.functional.einsum('cb,cab->ca', einsum_1, a); einsum_1 = a = None return einsum_2 ``` We can measure the performance improvement (this is on a CPU): ```python from torch.utils.benchmark import Timer batch = 1000 a, b, vec = torch.randn(batch, 4, 5), torch.randn(batch, 5, 8), torch.randn(batch, 8) g = {"f": graph_mod, "a": a, "b": b, "vec": vec} t_orig = Timer("f(a, b, vec)", globals=g) print(t_orig.timeit(10_000)) g["f"] = graph_opt t_opt = Timer("f(a, b, vec)", globals=g) print(t_opt.timeit(10_000)) ``` gives ~2x improvement: ``` f(a, b, vec) 276.58 us 1 measurement, 10000 runs , 1 thread f(a, b, vec) 118.84 us 1 measurement, 10000 runs , 1 thread ``` Depending on your function and dimensions you may see even larger improvements. ## License `opt_einsum_fx` is distributed under an [MIT license](LICENSE).python-opt-einsum-fx-0.1.4/docs/000077500000000000000000000000001504647741400165125ustar00rootroot00000000000000python-opt-einsum-fx-0.1.4/docs/Makefile000066400000000000000000000011721504647741400201530ustar00rootroot00000000000000# Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
%: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) python-opt-einsum-fx-0.1.4/docs/api.rst000066400000000000000000000001411504647741400200110ustar00rootroot00000000000000Full Reference ============== .. automodule:: opt_einsum_fx :members: :imported-members:python-opt-einsum-fx-0.1.4/docs/conf.py000066400000000000000000000036201504647741400200120ustar00rootroot00000000000000# Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # # import os # import sys # sys.path.insert(0, os.path.abspath('.')) # -- Project information ----------------------------------------------------- project = "opt_einsum_fx" copyright = "2021, Linux-cpp-lisp" author = "Linux-cpp-lisp" # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx_rtd_theme"] # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = "sphinx_rtd_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] python-opt-einsum-fx-0.1.4/docs/index.rst000066400000000000000000000007431504647741400203570ustar00rootroot00000000000000.. opt_einsum_fx documentation master file, created by sphinx-quickstart on Wed May 26 15:06:08 2021. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. Welcome to opt_einsum_fx's documentation! ========================================= .. toctree:: :maxdepth: 2 :caption: Contents: tutorial.rst api.rst Indices and tables ================== * :ref:`genindex` * :ref:`modindex` * :ref:`search` python-opt-einsum-fx-0.1.4/docs/make.bat000066400000000000000000000014331504647741400201200ustar00rootroot00000000000000@ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=. set BUILDDIR=_build if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. 
echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd python-opt-einsum-fx-0.1.4/docs/tutorial.rst000066400000000000000000000077321504647741400211200ustar00rootroot00000000000000Tutorial ======== ``opt_einsum_fx`` is a library for optimizng einsums and functions involving them using `opt_einsum `_ and `PyTorch FX `_ compute graphs. The library currently supports: - Fusing multiple einsums into one - Optimizing einsums using the `opt_einsum `_ library - Fusing multiplication and division with scalar constants, including fusing _through_ operations, like einsum, that commute with scalar multiplication. - Placing multiplication by fused scalar constants onto the smallest intermediate in a chain of operations that commute with scalar multiplication. ``opt_einsum_fx`` is based on `torch.fx `_, a framework for converting between PyTorch Python code and a programatically manipulable compute graph. To use this package, it must be possible to get your function or model as a `torch.fx.Graph`: the limitations of FX's symbolic tracing are discussed `here `_. Minimal example --------------- .. code-block:: python import torch import torch.fx import opt_einsum_fx def einmatvecmul(a, b, vec): """Batched matrix-matrix-vector product using einsum""" return torch.einsum("zij,zjk,zk->zi", a, b, vec) graph_mod = torch.fx.symbolic_trace(einmatvecmul) print("# Original code:\n", graph_mod.code) graph_opt = opt_einsum_fx.optimize_einsums_full( model=graph_mod, example_inputs=( torch.randn(7, 4, 5), torch.randn(7, 5, 3), torch.randn(7, 3) ) ) print("# Optimized code:\n", graph_opt.code) outputs: .. code-block:: python # Original code: import torch def forward(self, a, b, vec): einsum_1 = torch.functional.einsum('zij,zjk,zk->zi', a, b, vec); a = b = vec = None return einsum_1 # Optimized code: import torch def forward(self, a, b, vec): einsum_1 = torch.functional.einsum('cb,cab->ca', vec, b); vec = b = None einsum_2 = torch.functional.einsum('cb,cab->ca', einsum_1, a); einsum_1 = a = None return einsum_2 The ``optimize_einsums_full`` function has four passes: 1. Scalar accumulation --- use the multilinearity of einsum to fuse all constant coefficients and divisors of operands and outputs 2. Fusing einsums --- gives greater flexibility to (3) 3. Optimized contraction with ``opt_einsum`` 4. Moving constant scalar coefficients through operations they commute with in order to place them on the smallest possible intermediate results We can measure the performance improvement (this is on a CPU): .. code-block:: python from torch.utils.benchmark import Timer batch = 1000 a, b, vec = torch.randn(batch, 4, 5), torch.randn(batch, 5, 8), torch.randn(batch, 8) g = {"f": graph_mod, "a": a, "b": b, "vec": vec} t_orig = Timer("f(a, b, vec)", globals=g) print(t_orig.timeit(10_000)) g["f"] = graph_opt t_opt = Timer("f(a, b, vec)", globals=g) print(t_opt.timeit(10_000)) gives ~2x improvement: .. code-block:: none f(a, b, vec) 276.58 us 1 measurement, 10000 runs , 1 thread f(a, b, vec) 118.84 us 1 measurement, 10000 runs , 1 thread Depending on your function and dimensions you may see even larger improvements. JIT --- Currently, pure Python and TorchScript have different call signatures for `torch.tensordot` and `torch.permute`, both of which can appear in optimized einsums: .. 
code-block:: python graph_script = torch.jit.script(graph_opt) # => RuntimeError: Arguments for call are not valid... A function is provided to convert ``torch.fx.GraphModule`` s containing these operations from their Python signatures — the default — to a TorchScript compatible form: .. code-block:: python graph_script = torch.jit.script(opt_einsum_fx.jitable(graph_opt))python-opt-einsum-fx-0.1.4/opt_einsum_fx/000077500000000000000000000000001504647741400204415ustar00rootroot00000000000000python-opt-einsum-fx-0.1.4/opt_einsum_fx/__init__.py000066400000000000000000000005531504647741400225550ustar00rootroot00000000000000__version__ = "0.1.4" from ._script import jitable from ._opt_ein import optimize_einsums, optimize_einsums_full from ._fuse import fuse_einsums, fuse_scalars from ._efficient_shape_prop import EfficientShapeProp __all__ = [ "jitable", "optimize_einsums", "optimize_einsums_full", "fuse_einsums", "fuse_scalars", "EfficientShapeProp", ] python-opt-einsum-fx-0.1.4/opt_einsum_fx/_efficient_shape_prop.py000066400000000000000000000072371504647741400253370ustar00rootroot00000000000000from typing import Any, NamedTuple import opt_einsum import torch from torch.fx.node import Node from ._fuse import _EINSUM_FUNCS class SimpleMeta(NamedTuple): """ The full ShapeProp defines and uses a NamedTuple to store a whole bunch of metadata about the tensors going into and out of the Node op. But we don't have most of that info, and anyway, I don't think most of it's used in opt_einsum or opt_einsum_fx. (These are only concerned with computing a summation order.) Rather than give dummy or default values, which I only *assume* would be fine, I'm defining a NamedTuple with only the values we actually know. So if I'm wrong we will get a very clear error message, rather than some invisible error. """ shape: torch.Size dtype: torch.dtype class EfficientShapeProp(torch.fx.Interpreter): """ Like ShapeProp, traverses a graph Node-by-Node and records the shape and type of the result into each Node. Except we treat 'einsum' as a special case. We don't actually execute 'einsum' on tensors, since the einsums will typically not be optimized yet (ShapeProp is called before optimization), and inefficient summation order can create enormous intermediate tensors, which often creates needless out-of-memory errors. So we override 'run_node' only for 'einsums'. It's straightforward to determine the shape of the result just from the output indices. (The call to opt_einsum that will typically follow this, also doesn't actually build the tensors during its exploration.) 
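    For example, for a node ``torch.einsum("ij,jk->ik", a, b)`` with ``a`` of shape
    ``(3, 4)`` and ``b`` of shape ``(4, 5)``, the recorded output shape is simply
    ``(3, 5)`` (see ``einsum_shape`` below), obtained without materializing any
    intermediate tensors.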
""" def run_node(self, n: Node) -> Any: if n.op == "call_function" and n.target in _EINSUM_FUNCS: args, kwargs = self.fetch_args_kwargs_from_env(n) equation, *operands = args shapes = [op.shape for op in operands] assert len({op.dtype for op in operands}) == 1 meta = SimpleMeta(einsum_shape(equation, *shapes), operands[0].dtype) result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=operands[0].device).expand(meta.shape) elif n.op == "call_function" and n.target == torch.tensordot: args, kwargs = self.fetch_args_kwargs_from_env(n) shape_a = [dim for i, dim in enumerate(args[0].shape) if i not in kwargs['dims'][0]] shape_b = [dim for i, dim in enumerate(args[1].shape) if i not in kwargs['dims'][1]] assert len({op.dtype for op in args}) == 1 meta = SimpleMeta(shape_a + shape_b, args[0].dtype) result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=args[0].device).expand(meta.shape) else: result = super().run_node(n) if isinstance(result, torch.Tensor): meta = SimpleMeta(result.shape, result.dtype) else: meta = None n.meta = dict() n.meta['tensor_meta'] = meta n.meta['type'] = type(result) return result def propagate(self, *args): return super().run(*args) def einsum_shape(subscripts, *shapes): """ Given an einsum equation and input shapes, returns the output shape of the einsum. Args: subscripts: the einsum formula shapes: the input shapes """ Shaped = NamedTuple('Shaped', [('shape', tuple)]) input_subscripts, output_subscript, _ = opt_einsum.parser.parse_einsum_input( (subscripts,) + tuple(Shaped(shape) for shape in shapes) ) dims = { i: dim for ii, shape in zip(input_subscripts.split(','), shapes) for i, dim in zip(ii, shape) } return tuple(dims[i] for i in output_subscript) python-opt-einsum-fx-0.1.4/opt_einsum_fx/_fuse.py000066400000000000000000000327261504647741400221260ustar00rootroot00000000000000from typing import List, Optional, Tuple import string import copy import operator import numbers import torch from torch import fx from opt_einsum.parser import find_output_str from .fx_utils import get_shape _EINSUM_FUNCS = {torch.functional.einsum, torch.einsum} # == Einsum fusion == def _get_einstrs(einstr: str) -> Tuple[List[str], str]: if "..." in einstr: raise NotImplementedError("Ellipsis `...` in einsum string not supported yet") tmp = einstr.split("->") if len(tmp) == 1: ops = tmp[0] out = find_output_str(ops) elif len(tmp) == 2: ops, out = tmp else: raise ValueError(f"Invalid einstr {einstr}") return ops.split(","), out def fuse_einsums(graph: fx.Graph, in_place: bool = False) -> fx.Graph: """Fuse einsums when possible. When the output of one einsum is only used as an operand in another einsum, the two einsums can be fused into one. Example: .. code-block:: python def fusable(x, y): z = torch.einsum("ij,jk->ik", x, y) return torch.einsum("ik,ij->i", z, x) g = torch.fx.symbolic_trace(fusable) print(fuse_einsums(g.graph).python_code("")) gives:: import torch def forward(self, x, y): einsum_2 = torch.functional.einsum('ib,bk,ij->i', x, y, x); x = y = None return einsum_2 Args: graph: the graph to process. in_place (bool, optional): whether to process ``graph`` in place. Returns: The graph with fused einsums. 
""" if not in_place: graph = copy.deepcopy(graph) for node in graph.nodes: if node.op == "call_function" and node.target in _EINSUM_FUNCS: our_inp_einstrs, our_out_einstr = _get_einstrs(node.args[0]) assert len(our_inp_einstrs) == len(node.args) - 1 avail_letters = iter( set(string.ascii_lowercase) - set.union(*(set(e) for e in our_inp_einstrs)) ) new_our_einstrs = [] new_our_args = [] we_fused_nodes = [] # Iterate over operands for inp_idex, inp in enumerate(node.args[1:]): if ( inp.op == "call_function" and inp.target in _EINSUM_FUNCS and len(inp.users) == 1 ): # This operand is the output of another einsum, and is not used by any other operation # As a result, we can fuse it its_inp_einstrs, its_out_einstr = _get_einstrs(inp.args[0]) if len(its_out_einstr) != len(our_inp_einstrs[inp_idex]): raise RuntimeError( f"Inconsistent rank: einsum `{node}`'s input {inp_idex} is the result of einsum {inp}; the output of `{inp}` is labeled `{its_out_einstr}` (rank {len(its_out_einstr)}), but the corresponding input of `{node}` is labeled `{our_inp_einstrs[inp_idex]}` (rank {len(our_inp_einstrs[inp_idex])})" ) # First, we need to figure out which of its output dimensions correspond to our dimensions: its_dim_to_ours = dict( zip(its_out_einstr, our_inp_einstrs[inp_idex]) ) # assign any labels that don't show up in the output of the previous einsum --- and thus dont have labels in the current einsum --- to new letters its_remaining_labels = set.union( *(set(e) for e in its_inp_einstrs) ) - set(its_dim_to_ours.keys()) try: its_dim_to_ours.update( dict((i, next(avail_letters)) for i in its_remaining_labels) ) except StopIteration: # We ran out of letters raise NotImplementedError( f"At einsum {node}, ran out of letters when trying to fuse parameter einsum {inp}. A fallback for this case is not yet implimented." ) else: # We had enough letters, finish adding the fuse del its_remaining_labels new_our_args.extend(inp.args[1:]) new_our_einstrs.extend( "".join(its_dim_to_ours[d] for d in es) for es in its_inp_einstrs ) we_fused_nodes.append(inp) else: # This argument is not from an einsum, or is from an einsum that is used elsewhere as well # Thus we just pass it through new_our_einstrs.append(our_inp_einstrs[inp_idex]) new_our_args.append(inp) # -- end iter over prev einsum inputs -- # Set the new values for the einstrs node.args = (f"{','.join(new_our_einstrs)}->{our_out_einstr}",) + tuple( new_our_args ) # Remove fused inputs for to_remove in we_fused_nodes: graph.erase_node(to_remove) # -- end case for einsum nodes -- # -- end iter over nodes -- return graph # == Scalar fusion == # # Note that in general we do not support scalar fusion through in-place operations; it complicates following things through the compute graph too much # TODO: ^ ??? # TODO: should the accumulation of constants happen in more than double precision? def _get_node_and_scalar(node: fx.Node) -> Tuple[fx.Node, Optional[numbers.Number]]: """Get a multiplicative scalar for an operation, if applicable.""" # This supports in-place *= and /= because fx traces them as normal operator.mul/div. 
if node.op == "call_function": if node.target == operator.mul or node.target == torch.mul: if isinstance(node.args[0], numbers.Number): return node.args[1], node.args[0] elif isinstance(node.args[1], numbers.Number): return node.args[0], node.args[1] elif node.target == operator.truediv or node.target == torch.div: if isinstance(node.args[1], numbers.Number): return node.args[0], 1.0 / node.args[1] elif node.op == "call_method": # TODO: this could _technically_ be wrong if the nodes `self` argument is not a (proxy to) a Tensor if node.target == "mul": if isinstance(node.args[1], numbers.Number): return node.args[0], node.args[1] elif node.target == "div": if isinstance(node.args[1], numbers.Number): return node.args[0], 1.0 / node.args[1] return node, None # Operations that are (almost) "multilinear", in the sense that they commute with scalar multiplication of their operands SCALAR_COMMUTE_OPS = [ torch.einsum, torch.functional.einsum, torch.tensordot, torch.functional.tensordot, "permute", # "reshape", "mul", "div", operator.mul, operator.truediv, ] def prod(x): """Compute the product of a sequence.""" out = 1 for a in x: out *= a return out def fuse_scalars(graph: fx.Graph, in_place: bool = False) -> fx.Graph: """Use the multilinearity of einsum to unify and remove constant scalars around einsums. Args: graph: the graph to process. in_place (bool, optional): whether to process ``graph`` in place. Returns: The graph with fused scalars. """ if not in_place: graph = copy.deepcopy(graph) # Clear any previous state this graph has for node in graph.nodes: if hasattr(node, "in_lin_chain"): delattr(node, "in_lin_chain") # Find chains of multilinear ops seen_nodes = set() linear_chains = [] for node in graph.nodes: if id(node) in seen_nodes: continue # Determine a linear chain cur_linear_chain = [] while ( id(node) not in seen_nodes and getattr(node, "target", None) in SCALAR_COMMUTE_OPS ): seen_nodes.add(id(node)) node.in_lin_chain = len(linear_chains) cur_linear_chain.append(node) # Continue building the chain regardless, since the merger uses this users = list(node.users.keys()) if len(users) > 0: # Get the next node in the chain node = users[0] else: # This isn't used in the graph at all, break the chain node = None if len(users) != 1: # End this chain break # If the next user, which is now in node, was seen but is itself in a linear chain, this means we merge them # TODO: thoroughly test this if hasattr(node, "in_lin_chain") and len(cur_linear_chain) > 0: # Merge merge_into = node.in_lin_chain for n in cur_linear_chain: n.in_lin_chain = merge_into linear_chains[merge_into].extend(cur_linear_chain) else: # This is a new chain linear_chains.append(cur_linear_chain) # Accumulate scalars in them scalars = [] for lin_chain_i, lin_chain in enumerate(linear_chains): if len(lin_chain) < 2: # There's nothing to do here: either the chain is empty, # or there's only one operation — even if its a scalar multiplication, # theres nothing for us to do with it scalars.append(None) continue # Accumulate scalars scalar_node_idexes = [] total_scalar = 1.0 for node_i, node in enumerate(lin_chain): new_node, scalar = _get_node_and_scalar(node) if scalar is not None: total_scalar *= scalar scalar_node_idexes.append(node_i) is_all_scalars = len(scalar_node_idexes) == len(lin_chain) # Remove scalar nodes for node_i in scalar_node_idexes: node = lin_chain[node_i] new_node, scalar = _get_node_and_scalar(node) assert scalar is not None if is_all_scalars and node_i == len(lin_chain) - 1: # If it's all scalars, we just 
put the total_scalar into the last operation # and don't save a scalar for later with graph.inserting_after(node): new_node = graph.call_function( operator.mul, (total_scalar, new_node), ) total_scalar = None node.replace_all_uses_with(new_node) graph.erase_node(node) # Save the scalar for this chain scalars.append(total_scalar) # Remove all of the removed scalar operations from the lin chain # See https://stackoverflow.com/a/11303234/1008938 for index in sorted( (scalar_node_idexes[:-1] if is_all_scalars else scalar_node_idexes), reverse=True, ): del lin_chain[index] del seen_nodes # Make sure everything is still OK graph.lint() # Now we have chains without scalar operations; we can go through and add back in the scalars in the optimal place for lin_chain_i, lin_chain in enumerate(linear_chains): if ( len(lin_chain) == 0 or scalars[lin_chain_i] == 1.0 or scalars[lin_chain_i] is None ): # Nothing to do with an empty chain # No reason to add back a scalar that does nothing # None signals don't process from above continue # Find the smallest argument or the output smallest_node_i = None smallest_arg_i = None smallest_size = float("inf") for node_i, node in enumerate(lin_chain): for arg_i, arg in enumerate(node.args): if not isinstance(arg, fx.Node): continue shape = get_shape(arg) if shape is not None and prod(shape) < smallest_size: smallest_node_i = node_i smallest_arg_i = arg_i smallest_size = prod(shape) # Put the accumulated scalar on a node if (smallest_node_i is None) or ( get_shape(lin_chain[-1]) is not None and prod(get_shape(lin_chain[-1])) < smallest_size ): # The output is the smallest, put it there # OR there was no smallest argument, put it on the end of the chain with graph.inserting_after(lin_chain[-1]): new_node = graph.call_function(operator.mul, tuple()) # placeholder lin_chain[-1].replace_all_uses_with(new_node) new_node.args = (lin_chain[-1], scalars[lin_chain_i]) else: # The smallest was someone's arg, so we replace that with a scalar multiplication: with graph.inserting_before(lin_chain[smallest_node_i]): new_arg = graph.call_function( operator.mul, ( lin_chain[smallest_node_i].args[smallest_arg_i], scalars[lin_chain_i], ), ) new_args = list(lin_chain[smallest_node_i].args) new_args[smallest_arg_i] = new_arg lin_chain[smallest_node_i].args = tuple(new_args) graph.lint() return graph python-opt-einsum-fx-0.1.4/opt_einsum_fx/_opt_ein.py000066400000000000000000000147751504647741400226250ustar00rootroot00000000000000from typing import Callable, Union import warnings import torch from torch import fx from ._efficient_shape_prop import EfficientShapeProp as ShapeProp import opt_einsum from opt_einsum.contract import _core_contract from ._fuse import fuse_einsums, fuse_scalars, _EINSUM_FUNCS from .fx_utils import get_shape def optimize_einsums_full( model: Union[torch.nn.Module, Callable, fx.Graph], example_inputs: tuple, contract_kwargs: dict = {}, tracer_class: type = fx.Tracer, ) -> Union[fx.GraphModule, fx.Graph]: """Optimize einsums in ``model`` for ``example_inputs``. All of the restrictions of ``torch.fx`` symbolic tracing apply. Applies, in order, four optimizations: 1. Scalar accumulation --- use the multilinearity of einsum to collect all constant coefficients and divisors of operands and outputs 2. Fusing einsums --- gives greater flexibility to (3) 3. Optimized contraction with ``opt_einsum``. 4. 
Moving constant scalar coefficients through operations they commute with in order to place them on the smallest possible intermediate results Args: model (torch.nn.Module or callable or fx.Graph): the model, function, or ``fx.Graph`` to optimize. example_inputs (tuple): arguments to ``model`` whose shapes will determine the einsum optimizations. tracer_class (type, optional): the tracer class to use to turn ``model`` into an ``fx.Graph`` if it isn't already an ``fx.GraphModule`` or ``fx.Graph``. Returns: An optimized ``fx.GraphModule``, or if ``model`` is an ``fx.Graph``, an optimized ``fx.Graph``. """ output_graph = False if isinstance(model, fx.GraphModule): graph: fx.Graph = model.graph elif isinstance(model, fx.Graph): graph: fx.Graph = model model = torch.nn.Module() output_graph = True else: tracer: fx.Tracer = tracer_class() graph: fx.Graph = tracer.trace(model) model = tracer.root # 1. Scalar accumulation # without shape information, this just accumulates scalars and moves them to the end of chains of linear operations graph = fuse_scalars(graph) # 2. Fuse any einsums we can # This gives opt_einsum the most freedom possible to rearange things # Since we already moved scalars to the end of chains of linear operations, any scalars between linear operations should already have been moved graph = fuse_einsums(graph, in_place=True) out_mod = fx.GraphModule(model, graph) # 3. Shape propagation sp = ShapeProp(out_mod) sp.run(*example_inputs) # 4. Optimize einsums out_mod.graph = optimize_einsums(out_mod.graph, contract_kwargs) out_mod.recompile() # 5. Shape prop (again) # We need shapes to put the scalars in the best place sp = ShapeProp(out_mod) sp.run(*example_inputs) # 6. Final scalar fusion to move scalars out_mod.graph = fuse_scalars(out_mod.graph, in_place=True) if output_graph: return out_mod.graph else: out_mod.recompile() return out_mod # Based on "Proxy Retracing" example in https://pytorch.org/docs/stable/fx.html def optimize_einsums(graph: fx.Graph, contract_kwargs: dict = {}) -> fx.Graph: """Optimize einsums in a ``torch.fx.Graph`` using ``opt_einsum``. ``graph`` must have shape information such as that populated by ``torch.fx.passes.shape_prop.ShapeProp``. The shapes are used for ``opt_einsum`` and the result is specific to the number of dimensions in the provided shapes ``opt_einsum``: ...while it will work for a set of arrays with the same ranks as the original shapes but differing sizes, it might no longer be optimal. See the ``opt_einsum`` `documentation `_ for more details. Args: graph (fx.Graph): the graph to optimize contract_kwargs: extra keyword arguments for ``opt_einsum.contract_path``. Returns: An optimized ``fx.Graph``. """ defaults = { "optimize": "optimal", } defaults.update(contract_kwargs) contract_kwargs = defaults new_graph = fx.Graph() tracer = fx.proxy.GraphAppendingTracer(new_graph) # env keeps track of new injected nodes in addition to existing ones, # making sure they get into new_graph env = {} node_processed: bool = False for node in graph.nodes: node_processed = False if node.op == "call_function" and node.target in _EINSUM_FUNCS: # Get shapes shapes = [get_shape(a) for a in node.args[1:]] if any(s is None for s in shapes): warnings.warn( f"einsum {repr(node)} lacked shape information; " "not optimizing. 
" "Did you forget to run ShapeProp on this graph?", RuntimeWarning, ) else: # We have shapes, so: # Determine the optimal contraction path, path_info = opt_einsum.contract_path( node.args[0], # the einstr *shapes, shapes=True, **contract_kwargs, ) # By wrapping the arguments with proxies, # we can dispatch to opt_einsum and implicitly # add it to the Graph by symbolically tracing it. proxy_args = [ fx.Proxy(env[x.name], tracer=tracer) if isinstance(x, fx.Node) else x for x in node.args ] # Use _core_contract to avoid `len()` calls that # fx can't deal with output_proxy = _core_contract( proxy_args[1:], path_info.contraction_list, backend="torch", evaluate_constants=False, ) # Operations on `Proxy` always yield new `Proxy`s, and the # return value of our decomposition rule is no exception. # We need to extract the underlying `Node` from the `Proxy` # to use it in subsequent iterations of this transform. new_node = output_proxy.node env[node.name] = new_node node_processed = True if not node_processed: # Default case: just copy the node over into the new graph. new_node = new_graph.node_copy(node, lambda x: env[x.name]) env[node.name] = new_node new_graph.lint() return new_graph python-opt-einsum-fx-0.1.4/opt_einsum_fx/_script.py000066400000000000000000000036461504647741400224670ustar00rootroot00000000000000from typing import Union from packaging import version import torch from torch import fx # see https://github.com/pytorch/pytorch/issues/53487 def jitable(obj: Union[fx.GraphModule, fx.Graph]) -> Union[fx.GraphModule, fx.Graph]: """Convert some torch calls into their TorchScript signatures. In place. Currently deals with ``tensordot`` and ``permute``. Args: obj: the ``fx.Graph`` or ``fx.GraphModule`` to process. Returns: ``obj``, modified in-place. 
""" if isinstance(obj, fx.GraphModule): graph = obj.graph else: graph = obj torch_is_ge_19: bool = version.parse(torch.__version__) >= version.parse("1.9.0") for node in graph.nodes: if node.op == "call_function": if ( node.target == torch.tensordot or node.target == torch.functional.tensordot ): if "dims" in node.kwargs: args = list(node.args) kwargs = dict(node.kwargs) dim_self, dim_other = kwargs.pop("dims") assert len(args) == 2 # tensors 1 and 2 if torch_is_ge_19: # In torch >= 1.9.0, they've corrected the torchscript interface # to align with the python one: args.append((list(dim_self), list(dim_other))) else: args.append(list(dim_self)) args.append(list(dim_other)) node.args = tuple(args) node.kwargs = kwargs elif node.op == "call_method": if node.target == "permute": self_arg, args = node.args[0], node.args[1:] if not isinstance(args[0], list): node.args = [self_arg, list(args)] graph.lint() if isinstance(obj, fx.GraphModule): obj.recompile() return obj python-opt-einsum-fx-0.1.4/opt_einsum_fx/fx_utils.py000066400000000000000000000004661504647741400226560ustar00rootroot00000000000000from typing import Optional import torch from torch import fx def get_shape(n: fx.Node) -> Optional[torch.Size]: """Get the shape of a node after ``ShapeProp``""" try: return n.meta["tensor_meta"].shape except KeyError: return None except AttributeError: return None python-opt-einsum-fx-0.1.4/pyproject.toml000066400000000000000000000001471504647741400205000ustar00rootroot00000000000000[build-system] requires = [ "setuptools>=42", "wheel" ] build-backend = "setuptools.build_meta"python-opt-einsum-fx-0.1.4/setup.cfg000066400000000000000000000001751504647741400174060ustar00rootroot00000000000000[flake8] max-line-length = 127 select = E,F,W,C ignore = E226,E501,E741,E743,C901,W503 exclude = .eggs,*.egg,build,dist,docs python-opt-einsum-fx-0.1.4/setup.py000066400000000000000000000016761504647741400173060ustar00rootroot00000000000000import setuptools with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() setuptools.setup( name="opt_einsum_fx", version="0.1.4", # remember to update in `__init__.py` too! author="Linux-cpp-lisp", url="https://github.com/Linux-cpp-lisp/opt_einsum_fx", description="Einsum optimization using opt_einsum and PyTorch FX", long_description=long_description, long_description_content_type="text/markdown", license="MIT", license_files="LICENSE", project_urls={ "Bug Tracker": "https://github.com/Linux-cpp-lisp/opt_einsum_fx/issues", }, classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Development Status :: 4 - Beta", ], python_requires=">=3.6", install_requires=["torch>=1.8.0", "opt_einsum", "packaging"], packages=["opt_einsum_fx"], ) python-opt-einsum-fx-0.1.4/tests/000077500000000000000000000000001504647741400167245ustar00rootroot00000000000000python-opt-einsum-fx-0.1.4/tests/conftest.py000066400000000000000000000016451504647741400211310ustar00rootroot00000000000000import pytest import torch FLOAT_TOLERANCE = { t: torch.as_tensor(v, dtype=t) for t, v in {torch.float32: 1e-5, torch.float64: 1e-10}.items() } @pytest.fixture(scope="session", autouse=True, params=["float32", "float64"]) def float_tolerance(request): """Run all tests with various PyTorch default dtypes. This is a session-wide, autouse fixture — you only need to request it explicitly if a test needs to know the tolerance for the current default dtype. Returns -------- A precision threshold to use for closeness tests. 
""" old_dtype = torch.get_default_dtype() dtype = {"float32": torch.float32, "float64": torch.float64}[request.param] torch.set_default_dtype(dtype) yield FLOAT_TOLERANCE[dtype] torch.set_default_dtype(old_dtype) @pytest.fixture(scope="session") def allclose(float_tolerance): return lambda x, y: torch.allclose(x, y, atol=float_tolerance) python-opt-einsum-fx-0.1.4/tests/test_einsum_optimizer.py000066400000000000000000000057521504647741400237500ustar00rootroot00000000000000import pytest import torch import torch.fx from opt_einsum_fx import optimize_einsums, optimize_einsums_full, jitable, EfficientShapeProp def einmatmul(x, y): return torch.einsum("ij,jk->ik", x, y) def eintrace(x, y): # these indexings make it square b = torch.einsum("ii", x[:, : x.shape[0]]) return torch.einsum("jj", y[:, : y.shape[0]]) * b def fusable(x, y): z = torch.einsum("ij,jk->ik", x, y) return torch.einsum("ik,ij->i", z, x) def fusable_w_scalars(x, y): z = torch.einsum("ij,jk->ik", x, y) / 3.0 return 4.0 * torch.einsum("ik,ij->i", z, x) def unfusable(x, y): z = torch.einsum("ij,jk->ik", x, y) # We use z as something besides an input to the second einsum, so it is unfusable return torch.einsum("ik,ij->i", z, x) + z[:, 0] def unfusable_w_scalars(x, y): z = 2.7 * torch.einsum("ij,jk->ik", x, y) # We use z as something besides an input to the second einsum, so it is unfusable return torch.einsum("ik,ij->i", z, x) + 1.1 * z[:, 0] def not_einsum(x, y): # Try to trip it up with lots of scalar fusion but no einsums return 3.0 * 2.7 * x.sum() + (4.6 / y.relu().sum()) def not_einsum2(x, y): a = x.tanh().relu().sum() - y.sum() b = 3.41 * y.sum().tanh() return a - 6.7 * b @pytest.fixture( scope="module", params=[ einmatmul, eintrace, fusable, fusable_w_scalars, unfusable, unfusable_w_scalars, not_einsum, not_einsum2, ], ) def einfunc(request): return request.param def test_optimize_einsums(einfunc, allclose): x = torch.randn(3, 4) y = torch.randn(4, 5) func_res = einfunc(x, y) func_fx = torch.fx.symbolic_trace(einfunc) sp = EfficientShapeProp(func_fx) sp.run(x, y) func_fx_res = func_fx(x, y) assert torch.all(func_res == func_fx_res) graph_opt = optimize_einsums(func_fx.graph) func_fx.graph = graph_opt func_fx.recompile() func_opt_res = func_fx(x, y) assert allclose(func_opt_res, func_fx_res) def test_optimize_einsums_full(einfunc, allclose): x = torch.randn(3, 4) y = torch.randn(4, 5) func_res = einfunc(x, y) func_opt = optimize_einsums_full(einfunc, (x, y)) assert allclose(func_res, func_opt(x, y)) def test_fallback(): # We only bother to test this for one function einfunc = fusable # If there is no shape propagation, it should warn # and not do anything. 
func_fx = torch.fx.symbolic_trace(einfunc) old_code = func_fx.code with pytest.warns(RuntimeWarning): graph_opt = optimize_einsums(func_fx.graph) func_fx.graph = graph_opt func_fx.recompile() assert old_code == func_fx.code def test_torchscript(einfunc, allclose): x = torch.randn(3, 4) y = torch.randn(4, 5) func_res = einfunc(x, y) mod_opt = optimize_einsums_full(einfunc, (x, y)) mod_opt = jitable(mod_opt) mod_opt = torch.jit.script(mod_opt) func_opt_res = mod_opt(x, y) assert allclose(func_opt_res, func_res) python-opt-einsum-fx-0.1.4/tests/test_fuse.py000066400000000000000000000112521504647741400213000ustar00rootroot00000000000000import pytest import math import operator import torch import torch.fx from opt_einsum_fx import fuse_einsums, fuse_scalars, optimize_einsums_full def test_einsum_fuse(allclose): def fusable(x, y): z = torch.einsum("ij,jk->ik", x, y) return torch.einsum("ik,ij->i", z, x) g = torch.fx.symbolic_trace(fusable) new_graph = fuse_einsums(g.graph) g.graph = new_graph g.recompile() x, y = torch.randn(3, 4), torch.randn(4, 5) out_truth = fusable(x, y) out_fused = g(x, y) assert allclose(out_fused, out_truth) def test_unfusable(): def unfusable(x, y): z = torch.einsum("ij,jk->ik", x, y) # We use z as something besides an input to the second einsum, so it is unfusable return torch.einsum("ik,ij->i", z, x) + z[:, 0] g = torch.fx.symbolic_trace(unfusable) old_code = g.code new_graph = fuse_einsums(g.graph) g.graph = new_graph g.recompile() # Confirm numerical equivalence x, y = torch.randn(3, 4), torch.randn(4, 5) out_truth = unfusable(x, y) out_fused = g(x, y) # Here we use normal allclose --- since unfusable is unfusable, # nothing should have changed. assert torch.allclose(out_fused, out_truth) # Confirm no fusion: assert old_code == g.code def test_doublefuse(allclose): def doublefuse(a, b, c, d): # quadruple matmul with a final transpose e1 = torch.einsum("ij,jk->ik", a, b) e2 = torch.einsum("ab,bc->ac", e1, c) return torch.einsum("tr,ry->yt", e2, d) g = torch.fx.symbolic_trace(doublefuse) new_graph = fuse_einsums(g.graph) g.graph = new_graph g.recompile() a, b, c, d = ( torch.randn(3, 4), torch.randn(4, 5), torch.randn(5, 2), torch.randn(2, 3), ) out_truth = doublefuse(a, b, c, d) out_fused = g(a, b, c, d) assert allclose(out_fused, out_truth) def test_inconsistent(): def inconsistent(x, y): z = torch.einsum("ij,jk->ik", x, y) # Note that the dimension labels for z have the wrong length return torch.einsum("i,ij->i", z, x) g = torch.fx.symbolic_trace(inconsistent) with pytest.raises(RuntimeError): _ = fuse_einsums(g.graph) def scalar_fusable1(x, y): return 7.0 * torch.einsum("ij,jk->ik", x, y / 3) / 2 def scalar_fusable2(x, y): return 4.0 * torch.einsum("ij,jk->ik", x, 2.0 * y / 3) / 2 def scalar_fusable3(x, y): return 4.0 * torch.einsum("ij,jk->ik", x / 1.2, 1.7 * 2.0 * y / 3) / 2 def scalar_unfusable(x, y): z = 3 * torch.einsum("ij,jk->ik", x, y) / 4.0 # We use z as something besides an input to the second einsum, so it is unfusable return (2.0 * torch.einsum("ik,ij->i", z, x)) + z[:, 0] def just_scalars(x, y): return 3.0 * x def just_many_scalars(x, y): return 3.0 / 3.4 * x / 4.0 def in_place(x, y): # This *shouldn't* be fused. 
a = x.clone() b = a.mul_(4.0) return 3.0 * b def unused(x, y): b = 2.3 * x / 4.5 # noqa return 4.6 * torch.einsum("ij,jk->ik", x, y) def constants(x, y): return math.pi * torch.einsum("ij,jk->ik", x, math.e * y / 3) / 2 # In all cases but unfusable, after fusion, the graph should have 5 nodes: # two placeholders, one einsum, one mul, and one output @pytest.mark.parametrize( "func", [ (scalar_fusable1, 5), (scalar_fusable2, 5), (scalar_fusable3, 5), ( scalar_unfusable, 9, # two placeholders, one einsum one mul, one einsum one mul, one getitem, one sum, and one output = 9 ), (just_scalars, 4), (just_many_scalars, 4), (in_place, 6), (constants, 5), (unused, 6), ], ) def test_scalar_fuse(allclose, func): func, truth_num_nodes = func g = torch.fx.symbolic_trace(func) print("old graph\n", g.graph) new_graph = fuse_scalars(g.graph) print("new graph\n", new_graph) g.graph = new_graph assert len(g.graph.nodes) == truth_num_nodes g.recompile() x, y = torch.randn(3, 4), torch.randn(4, 5) out_truth = func(x, y) out_fused = g(x, y) assert allclose(out_fused, out_truth) def test_scalar_positioning(allclose): def f(x, y, z): return 0.784 * torch.einsum("ij,jk,kl->il", x, y, z) x, y, z = torch.randn(2, 100), torch.randn(100, 2), torch.randn(2, 100) # note that the smallest here is y g = torch.fx.symbolic_trace(f) print("old graph\n", g.graph) g = optimize_einsums_full(g, (x, y, z)) print("new graph\n", g.graph) # optimal placement is on the 2x2 intermediate assert list(g.graph.nodes)[4].target == operator.mul out_truth = f(x, y, z) out_fused = g(x, y, z) assert allclose(out_fused, out_truth)
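# A minimal end-to-end sketch (hypothetical addition, not part of the original
# test suite): it assumes only the public API imported above and checks that a
# fully optimized scalar-times-einsum chain still matches eager execution.
def test_full_pipeline_sketch(allclose):
    def f(x, y, z):
        return 2.0 * torch.einsum("ij,jk,kl->il", x, y, z)

    x, y, z = torch.randn(3, 4), torch.randn(4, 5), torch.randn(5, 6)
    mod = optimize_einsums_full(f, (x, y, z))
    assert allclose(mod(x, y, z), f(x, y, z))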