Merge pull request 'zzc' (#3) from zzc into main

Reviewed-on: #3
main
zhangzhichao 1 month ago
commit 12aa1c2480

1
.gitignore vendored

@ -1,5 +1,6 @@
venv
*.pdf
*.PDF
.vscode
visual_images/*.jpg
__pycache__

@ -1,7 +1,8 @@
from typing import List
import cv2
from .utils import scanning_document_classify, text_rec, table_rec, scanning_document_rec, markdown_rec, assign_tables_to_titles, remove_watermark
from .utils import scanning_document_classify, table_rec, scanning_document_rec, markdown_rec, assign_tables_to_titles, remove_watermark
from tqdm import tqdm
from ..image_helper import text_rec
class LayoutRecognitionResult(object):
@ -60,18 +61,18 @@ def rec(page_detection_results, tmp_dir) -> List[List[LayoutRecognitionResult]]:
# 扫描件
is_scanning_document = True
content, layout_img = scanning_document_rec(layout_img)
source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
source_page_unwatermarked_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
elif layout.clsid == 4:
# table
if scanning_document_classify(layout_img):
is_scanning_document = True
content, layout_img = scanning_document_rec(layout_img)
source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
source_page_unwatermarked_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
else:
content = table_rec(layout_img)
elif layout.clsid == 5:
# table caption
ocr_results = text_rec(layout_img)
_, ocr_results, _ = text_rec(layout_img)
content = ''
for o in ocr_results:
content += f'{o}\n'
@ -81,25 +82,26 @@ def rec(page_detection_results, tmp_dir) -> List[List[LayoutRecognitionResult]]:
if not content:
continue
content = content.replace('\\', '')
result = LayoutRecognitionResult(layout.clsid, content, layout.pos)
outputs.append(result)
if is_scanning_document and len(outputs) == 1:
# 扫描件额外提取标题
h, w = source_page_no_watermark_img.shape[:2]
h, w = source_page_unwatermarked_img.shape[:2]
if h > w:
title_img = source_page_no_watermark_img[:360, :w, ...]
title_img = source_page_unwatermarked_img[:360, :w, ...]
# cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img)
# vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 0), (w, 360), (255, 255, 0), 3)
# vis = cv2.rectangle(source_page_unwatermarked_img.copy(), (0, 0), (w, 360), (255, 255, 0), 3)
# cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis)
else:
title_img = source_page_no_watermark_img[:410, :w, ...]
title_img = source_page_unwatermarked_img[:410, :w, ...]
# cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img)
# vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 310), (w, 410), (255, 255, 0), 3)
# vis = cv2.rectangle(source_page_unwatermarked_img.copy(), (0, 310), (w, 410), (255, 255, 0), 3)
# cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis)
title = text_rec(title_img)
_, title, _ = text_rec(title_img)
outputs[0].table_title = '\n'.join(title)
else:
# 自动给表格分配距离它最近的标题

@ -199,21 +199,19 @@ def parse_args(arg_list: Optional[List[str]] = None):
return args
try:
ocr_engine = importlib.import_module("rapidocr").RapidOCR()
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"Please install the rapidocr by pip install rapidocr"
) from exc
# try:
# ocr_engine = importlib.import_module("rapidocr").RapidOCR()
# except ModuleNotFoundError as exc:
# raise ModuleNotFoundError(
# "Please install the rapidocr by pip install rapidocr"
# ) from exc
input_args = RapidTableInput(model_type=ModelType.SLANETPLUS.value)
table_engine = RapidTable(input_args)
def table2md_pipeline(img):
rapid_ocr_output = ocr_engine(img)
ocr_result = list(
zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
)
def table2md_pipeline(img, ocr_result):
# rapid_ocr_output = ocr_engine(img)
# ocr_result = list(zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores))
table_results = table_engine(img, ocr_result)
html_content = table_results.pred_html
md_content = md(html_content)

@ -1,18 +0,0 @@
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.read_api import read_local_images
from markdownify import markdownify as md
import re
# proc
## Create Dataset Instance
input_file = "/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/207.jpg"
ds = read_local_images(input_file)[0]
x = ds.apply(doc_analyze, ocr=True)
x = x.pipe_ocr_mode(None)
html = x.get_markdown(None)
content = md(html)
content = re.sub(r'\\([#*_`])', r'\1', content)
print(content)

@ -2,7 +2,6 @@ import os
import tempfile
import cv2
import numpy as np
from paddleocr import PaddleOCR
from marker.converters.table import TableConverter
from marker.models import create_model_dict
from marker.output import text_from_rendered
@ -11,6 +10,7 @@ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.read_api import read_local_images
from markdownify import markdownify as md
import re
from ..image_helper import text_rec
def scanning_document_classify(image):
@ -66,46 +66,27 @@ def markdown_rec(image):
return html2md(html)
ocr = PaddleOCR(lang='ch') # need to run only once to download and load model into memory
def text_rec(image):
result = ocr.ocr(image, cls=True)
output = []
for idx in range(len(result)):
res = result[idx]
if not res:
continue
for line in res:
if not line:
continue
output.append(line[1][0])
return output
def table_rec(image):
return table2md_pipeline(image)
boxes, texts, conficences = text_rec(image)
ocr_result = list(zip(boxes, texts, conficences))
return table2md_pipeline(image, ocr_result)
table_converter = TableConverter(artifact_dict=create_model_dict())
def scanning_document_rec(image):
# TODO 内部的ocr可以替换为paddleocr以提升文字识别精度
image_path = f'{tempfile.mktemp()}.jpg'
cv2.imwrite(image_path, image)
tmp_image_path = f'{tempfile.mktemp()}.jpg'
try:
no_watermark_image = remove_watermark(cv2.imread(image_path))
prefix, suffix = image_path.split('.')
new_image_path = f'{prefix}_remove_watermark.{suffix}'
cv2.imwrite(new_image_path, no_watermark_image)
unwatermarked_image = remove_watermark(image)
cv2.imwrite(tmp_image_path, unwatermarked_image)
rendered = table_converter(new_image_path)
rendered = table_converter(tmp_image_path)
text, _, _ = text_from_rendered(rendered)
finally:
os.remove(image_path)
return text, no_watermark_image
os.remove(tmp_image_path)
return text, unwatermarked_image
def compute_box_distance(box1, box2):

@ -22,7 +22,7 @@ def create_connection():
return conn
except OperationalError as e:
logger.error(f"连接数据库失败: {e}")
return None
raise e
# 插入数据的函数

@ -4,6 +4,7 @@ import os
import paddleclas
import cv2
from .page_detection.utils import PageDetectionResult
from paddleocr import PaddleOCR
paddle_clas_model = paddleclas.PaddleClas(model_name="text_image_orientation")
@ -45,3 +46,26 @@ def page_detection_visual(page_detection_result: PageDetectionResult):
img = cv2.rectangle(img, (int(pos[0]), int(pos[1])), (int(pos[2]), int(pos[3])), color, 2)
cv2.putText(img, text, (int(pos[0]), int(pos[1])), cv2.FONT_HERSHEY_TRIPLEX, 1, color, 2)
return img
ocr = PaddleOCR(use_angle_cls=False, lang='ch', use_gpu=True, show_log=False)
def text_rec(image):
result = ocr.ocr(image, cls=False)
boxes = []
texts = []
conficences = []
for idx in range(len(result)):
res = result[idx]
if not res:
continue
for line in res:
if not line:
continue
box = line[0]
text = line[1][0]
confidence = line[1][1]
boxes.append(box)
texts.append(text)
conficences.append(confidence)
return boxes, texts, conficences

@ -4,7 +4,6 @@ from utils import non_max_suppression, merge_text_and_title_boxes, LayoutBox, Pa
from tqdm import tqdm
"""
0 - Text
1 - Title

@ -31,7 +31,7 @@ parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize
from picodet_postprocess import PicoDetPostProcess
from .picodet_postprocess import PicoDetPostProcess
from clrnet_postprocess import CLRNetPostProcess
from visualize import visualize_box_mask, imshow_lanes
from utils import argsparser, Timer, multiclass_nms, coco_clsid2catid
@ -254,7 +254,6 @@ class Detector(object):
self.pred_config.labels,
output_dir=self.output_dir,
threshold=self.threshold)
# TODO 在这里处理batch
results.append(result)
results = self.merge_batch_result(results)
boxes = results['boxes']

179
marker/.gitignore vendored

@ -0,0 +1,179 @@
private.py
.DS_Store
local.env
experiments
test_data
training
wandb
*.dat
report.json
benchmark_data
debug_data
temp.md
temp
conversion_results
uploads
/cache
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
.vscode/

@ -0,0 +1,12 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.9.10
hooks:
# Run the linter.
- id: ruff
types_or: [ python, pyi ]
args: [ --fix ]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi ]

@ -0,0 +1,24 @@
Marker Contributor Agreement
This Marker Contributor Agreement ("MCA") applies to any contribution that you make to any product or project managed by us (the "project"), and sets out the intellectual property rights you grant to us in the contributed materials. The term "us" shall mean Endless Labs, Inc. The term "you" shall mean the person or entity identified below.
If you agree to be bound by these terms, sign by writing "I have read the CLA document and I hereby sign the CLA" in response to the CLA bot Github comment. Read this agreement carefully before signing. These terms and conditions constitute a binding legal agreement.
1. The term 'contribution' or 'contributed materials' means any source code, object code, patch, tool, sample, graphic, specification, manual, documentation, or any other material posted or submitted by you to the project.
2. With respect to any worldwide copyrights, or copyright applications and registrations, in your contribution:
- you hereby assign to us joint ownership, and to the extent that such assignment is or becomes invalid, ineffective or unenforceable, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty free, unrestricted license to exercise all rights under those copyrights. This includes, at our option, the right to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements, including dual-license structures for commercial customers;
- you agree that each of us can do all things in relation to your contribution as if each of us were the sole owners, and if one of us makes a derivative work of your contribution, the one who makes the derivative work (or has it made will be the sole owner of that derivative work;
- you agree that you will not assert any moral rights in your contribution against us, our licensees or transferees;
- you agree that we may register a copyright in your contribution and exercise all ownership rights associated with it; and
- you agree that neither of us has any duty to consult with, obtain the consent of, pay or render an accounting to the other for any use or distribution of vour contribution.
3. With respect to any patents you own, or that you can license without payment to any third party, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty-free license to:
- make, have made, use, sell, offer to sell, import, and otherwise transfer your contribution in whole or in part, alone or in combination with or included in any product, work or materials arising out of the project to which your contribution was submitted, and
- at our option, to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements.
If you or your affiliates institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the contribution or any project it was submitted to constitutes direct or contributory patent infringement, then any patent licenses granted to you under this agreement for that contribution shall terminate as of the date such litigation is filed.
4. Except as set out above, you keep all right, title, and interest in your contribution. The rights that you grant to us under these terms are effective on the date you first submitted a contribution to us, even if your submission took place before the date you sign these terms. Any contribution we make available under any license will also be made available under a suitable FSF (Free Software Foundation) or OSI (Open Source Initiative) approved license.
5. You covenant, represent, warrant and agree that:
- each contribution that you submit is and shall be an original work of authorship and you can legally grant the rights set out in this MCA;
- to the best of your knowledge, each contribution will not violate any third party's copyrights, trademarks, patents, or other intellectual property rights; and
- each contribution shall be in compliance with U.S. export control laws and other applicable export and import laws.
You agree to notify us if you become aware of any circumstance which would make any of the foregoing representations inaccurate in any respect. Endless Labs, Inc. may publicly disclose your participation in the project, including the fact that you have signed the MCA.
6. This MCA is governed by the laws of the State of California and applicable U.S. Federal law. Any choice of law rules will not apply.

@ -0,0 +1,674 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
Marker pdf to markdown converter
Copyright (C) 2023 Vikas Paruchuri
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
Marker Copyright (C) 2023 Vikas Paruchuri
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

@ -0,0 +1,511 @@
# Marker
Marker converts documents to markdown, JSON, and HTML quickly and accurately.
- Converts PDF, image, PPTX, DOCX, XLSX, HTML, EPUB files in all languages
- Formats tables, forms, equations, inline math, links, references, and code blocks
- Extracts and saves images
- Removes headers/footers/other artifacts
- Extensible with your own formatting and logic
- Optionally boost accuracy with LLMs
- Works on GPU, CPU, or MPS
## Performance
<img src="data/images/overall.png" width="800px"/>
Marker benchmarks favorably compared to cloud services like Llamaparse and Mathpix, as well as other open source tools.
The above results are running single PDF pages serially. Marker is significantly faster when running in batch mode, with a projected throughput of 122 pages/second on an H100 (.18 seconds per page across 22 processes).
See [below](#benchmarks) for detailed speed and accuracy benchmarks, and instructions on how to run your own benchmarks.
## Hybrid Mode
For the highest accuracy, pass the `--use_llm` flag to use an LLM alongside marker. This will do things like merge tables across pages, handle inline math, format tables properly, and extract values from forms. It can use any gemini or ollama model. By default, it uses `gemini-2.0-flash`. See [below](#llm-services) for details.
Here is a table benchmark comparing marker, gemini flash alone, and marker with use_llm:
<img src="data/images/table.png" width="400px"/>
As you can see, the use_llm mode offers higher accuracy than marker or gemini alone.
## Examples
| PDF | File type | Markdown | JSON |
|-----|-----------|------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------|
| [Think Python](https://greenteapress.com/thinkpython/thinkpython.pdf) | Textbook | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/markdown/thinkpython/thinkpython.md) | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/json/thinkpython.json) |
| [Switch Transformers](https://arxiv.org/pdf/2101.03961.pdf) | arXiv paper | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/markdown/switch_transformers/switch_trans.md) | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/json/switch_trans.json) |
| [Multi-column CNN](https://arxiv.org/pdf/1804.07821.pdf) | arXiv paper | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/markdown/multicolcnn/multicolcnn.md) | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/json/multicolcnn.json) |
# Commercial usage
I want marker to be as widely accessible as possible, while still funding my development/training costs. Research and personal usage is always okay, but there are some restrictions on commercial usage.
The weights for the models are licensed `cc-by-nc-sa-4.0`, but I will waive that for any organization under \$5M USD in gross revenue in the most recent 12-month period AND under $5M in lifetime VC/angel funding raised. You also must not be competitive with the [Datalab API](https://www.datalab.to/). If you want to remove the GPL license requirements (dual-license) and/or use the weights commercially over the revenue limit, check out the options [here](https://www.datalab.to).
# Hosted API
There's a hosted API for marker available [here](https://www.datalab.to/):
- Supports PDFs, word documents, and powerpoints
- 1/4th the price of leading cloud-based competitors
- High uptime (99.99%), quality, and speed (around 15 seconds to convert a 250 page PDF)
# Community
[Discord](https://discord.gg//KuZwXNGnfH) is where we discuss future development.
# Installation
You'll need python 3.10+ and PyTorch. You may need to install the CPU version of torch first if you're not using a Mac or a GPU machine. See [here](https://pytorch.org/get-started/locally/) for more details.
Install with:
```shell
pip install marker-pdf
```
If you want to use marker on documents other than PDFs, you will need to install additional dependencies with:
```shell
pip install marker-pdf[full]
```
# Usage
First, some configuration:
- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`.
- Some PDFs, even digital ones, have bad text in them. Set the `force_ocr` flag to ensure your PDF runs through OCR, or the `strip_existing_ocr` to keep all digital text, and strip out any existing OCR text.
## Interactive App
I've included a streamlit app that lets you interactively try marker with some basic options. Run it with:
```shell
pip install streamlit
marker_gui
```
## Convert a single file
```shell
marker_single /path/to/file.pdf
```
You can pass in PDFs or images.
Options:
- `--output_dir PATH`: Directory where output files will be saved. Defaults to the value specified in settings.OUTPUT_DIR.
- `--output_format [markdown|json|html]`: Specify the format for the output results.
- `--paginate_output`: Paginates the output, using `\n\n{PAGE_NUMBER}` followed by `-` * 48, then `\n\n`
- `--use_llm`: Uses an LLM to improve accuracy. You must set your Gemini API key using the `GOOGLE_API_KEY` env var.
- `--redo_inline_math`: If you want the highest quality inline math conversion, use this along with `--use_llm`.
- `--disable_image_extraction`: Don't extract images from the PDF. If you also specify `--use_llm`, then images will be replaced with a description.
- `--page_range TEXT`: Specify which pages to process. Accepts comma-separated page numbers and ranges. Example: `--page_range "0,5-10,20"` will process pages 0, 5 through 10, and page 20.
- `--force_ocr`: Force OCR processing on the entire document, even for pages that might contain extractable text.
- `--strip_existing_ocr`: Remove all existing OCR text in the document and re-OCR with surya.
- `--debug`: Enable debug mode for additional logging and diagnostic information.
- `--processors TEXT`: Override the default processors by providing their full module paths, separated by commas. Example: `--processors "module1.processor1,module2.processor2"`
- `--config_json PATH`: Path to a JSON configuration file containing additional settings.
- `--languages TEXT`: Optionally specify which languages to use for OCR processing. Accepts a comma-separated list. Example: `--languages "en,fr,de"` for English, French, and German.
- `config --help`: List all available builders, processors, and converters, and their associated configuration. These values can be used to build a JSON configuration file for additional tweaking of marker defaults.
- `--converter_cls`: One of `marker.converters.pdf.PdfConverter` (default) or `marker.converters.table.TableConverter`. The `PdfConverter` will convert the whole PDF, the `TableConverter` will only extract and convert tables.
- `--llm_service`: Which llm service to use if `--use_llm` is passed. This defaults to `marker.services.gemini.GoogleGeminiService`.
- `--help`: see all of the flags that can be passed into marker. (it supports many more options then are listed above)
The list of supported languages for surya OCR is [here](https://github.com/VikParuchuri/surya/blob/master/surya/recognition/languages.py). If you don't need OCR, marker can work with any language.
## Convert multiple files
```shell
marker /path/to/input/folder --workers 4
```
- `marker` supports all the same options from `marker_single` above.
- `--workers` is the number of conversion workers to run simultaneously. This is set to 5 by default, but you can increase it to increase throughput, at the cost of more CPU/GPU usage. Marker will use 5GB of VRAM per worker at the peak, and 3.5GB average.
## Convert multiple files on multiple GPUs
```shell
NUM_DEVICES=4 NUM_WORKERS=15 marker_chunk_convert ../pdf_in ../md_out
```
- `NUM_DEVICES` is the number of GPUs to use. Should be `2` or greater.
- `NUM_WORKERS` is the number of parallel processes to run on each GPU.
## Use from python
See the `PdfConverter` class at `marker/converters/pdf.py` function for additional arguments that can be passed.
```python
from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict
from marker.output import text_from_rendered
converter = PdfConverter(
artifact_dict=create_model_dict(),
)
rendered = converter("FILEPATH")
text, _, images = text_from_rendered(rendered)
```
`rendered` will be a pydantic basemodel with different properties depending on the output type requested. With markdown output (default), you'll have the properties `markdown`, `metadata`, and `images`. For json output, you'll have `children`, `block_type`, and `metadata`.
### Custom configuration
You can pass configuration using the `ConfigParser`. To see all available options, do `marker_single --help`.
```python
from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict
from marker.config.parser import ConfigParser
config = {
"output_format": "json",
"ADDITIONAL_KEY": "VALUE"
}
config_parser = ConfigParser(config)
converter = PdfConverter(
config=config_parser.generate_config_dict(),
artifact_dict=create_model_dict(),
processor_list=config_parser.get_processors(),
renderer=config_parser.get_renderer(),
llm_service=config_parser.get_llm_service()
)
rendered = converter("FILEPATH")
```
### Extract blocks
Each document consists of one or more pages. Pages contain blocks, which can themselves contain other blocks. It's possible to programmatically manipulate these blocks.
Here's an example of extracting all forms from a document:
```python
from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict
from marker.schema import BlockTypes
converter = PdfConverter(
artifact_dict=create_model_dict(),
)
document = converter.build_document("FILEPATH")
forms = document.contained_blocks((BlockTypes.Form,))
```
Look at the processors for more examples of extracting and manipulating blocks.
## Other converters
You can also use other converters that define different conversion pipelines:
### Extract tables
The `TableConverter` will only convert and extract tables:
```python
from marker.converters.table import TableConverter
from marker.models import create_model_dict
from marker.output import text_from_rendered
converter = TableConverter(
artifact_dict=create_model_dict(),
)
rendered = converter("FILEPATH")
text, _, images = text_from_rendered(rendered)
```
This takes all the same configuration as the PdfConverter. You can specify the configuration `force_layout_block=Table` to avoid layout detection and instead assume every page is a table. Set `output_format=json` to also get cell bounding boxes.
You can also run this via the CLI with
```shell
marker_single FILENAME --use_llm --force_layout_block Table --converter_cls marker.converters.table.TableConverter --output_format json
```
# Output Formats
## Markdown
Markdown output will include:
- image links (images will be saved in the same folder)
- formatted tables
- embedded LaTeX equations (fenced with `$$`)
- Code is fenced with triple backticks
- Superscripts for footnotes
## HTML
HTML output is similar to markdown output:
- Images are included via `img` tags
- equations are fenced with `<math>` tags
- code is in `pre` tags
## JSON
JSON output will be organized in a tree-like structure, with the leaf nodes being blocks. Examples of leaf nodes are a single list item, a paragraph of text, or an image.
The output will be a list, with each list item representing a page. Each page is considered a block in the internal marker schema. There are different types of blocks to represent different elements.
Pages have the keys:
- `id` - unique id for the block.
- `block_type` - the type of block. The possible block types can be seen in `marker/schema/__init__.py`. As of this writing, they are ["Line", "Span", "FigureGroup", "TableGroup", "ListGroup", "PictureGroup", "Page", "Caption", "Code", "Figure", "Footnote", "Form", "Equation", "Handwriting", "TextInlineMath", "ListItem", "PageFooter", "PageHeader", "Picture", "SectionHeader", "Table", "Text", "TableOfContents", "Document"]
- `html` - the HTML for the page. Note that this will have recursive references to children. The `content-ref` tags must be replaced with the child content if you want the full html. You can see an example of this at `marker/output.py:json_to_html`. That function will take in a single block from the json output, and turn it into HTML.
- `polygon` - the 4-corner polygon of the page, in (x1,y1), (x2,y2), (x3, y3), (x4, y4) format. (x1,y1) is the top left, and coordinates go clockwise.
- `children` - the child blocks.
The child blocks have two additional keys:
- `section_hierarchy` - indicates the sections that the block is part of. `1` indicates an h1 tag, `2` an h2, and so on.
- `images` - base64 encoded images. The key will be the block id, and the data will be the encoded image.
Note that child blocks of pages can have their own children as well (a tree structure).
```json
{
"id": "/page/10/Page/366",
"block_type": "Page",
"html": "<content-ref src='/page/10/SectionHeader/0'></content-ref><content-ref src='/page/10/SectionHeader/1'></content-ref><content-ref src='/page/10/Text/2'></content-ref><content-ref src='/page/10/Text/3'></content-ref><content-ref src='/page/10/Figure/4'></content-ref><content-ref src='/page/10/SectionHeader/5'></content-ref><content-ref src='/page/10/SectionHeader/6'></content-ref><content-ref src='/page/10/TextInlineMath/7'></content-ref><content-ref src='/page/10/TextInlineMath/8'></content-ref><content-ref src='/page/10/Table/9'></content-ref><content-ref src='/page/10/SectionHeader/10'></content-ref><content-ref src='/page/10/Text/11'></content-ref>",
"polygon": [[0.0, 0.0], [612.0, 0.0], [612.0, 792.0], [0.0, 792.0]],
"children": [
{
"id": "/page/10/SectionHeader/0",
"block_type": "SectionHeader",
"html": "<h1>Supplementary Material for <i>Subspace Adversarial Training</i> </h1>",
"polygon": [
[217.845703125, 80.630859375], [374.73046875, 80.630859375],
[374.73046875, 107.0],
[217.845703125, 107.0]
],
"children": null,
"section_hierarchy": {
"1": "/page/10/SectionHeader/1"
},
"images": {}
},
...
]
}
```
## Metadata
All output formats will return a metadata dictionary, with the following fields:
```json
{
"table_of_contents": [
{
"title": "Introduction",
"heading_level": 1,
"page_id": 0,
"polygon": [...]
}
], // computed PDF table of contents
"page_stats": [
{
"page_id": 0,
"text_extraction_method": "pdftext",
"block_counts": [("Span", 200), ...]
},
...
]
}
```
# LLM Services
When running with the `--use_llm` flag, you have a choice of services you can use:
- `Gemini` - this will use the Gemini developer API by default. You'll need to pass `--gemini_api_key` to configuration.
- `Google Vertex` - this will use vertex, which can be more reliable. You'll need to pass `--vertex_project_id`. To use it, set `--llm_service=marker.services.vertex.GoogleVertexService`.
- `Ollama` - this will use local models. You can configure `--ollama_base_url` and `--ollama_model`. To use it, set `--llm_service=marker.services.ollama.OllamaService`.
- `Claude` - this will use the anthropic API. You can configure `--claude_api_key`, and `--claude_model_name`. To use it, set `--llm_service=marker.services.claude.ClaudeService`.
- `OpenAI` - this supports any openai-like endpoint. You can configure `--openai_api_key`, `--openai_model`, and `--openai_base_url`. To use it, set `--llm_service=marker.services.openai.OpenAIService`.
These services may have additional optional configuration as well - you can see it by viewing the classes.
# Internals
Marker is easy to extend. The core units of marker are:
- `Providers`, at `marker/providers`. These provide information from a source file, like a PDF.
- `Builders`, at `marker/builders`. These generate the initial document blocks and fill in text, using info from the providers.
- `Processors`, at `marker/processors`. These process specific blocks, for example the table formatter is a processor.
- `Renderers`, at `marker/renderers`. These use the blocks to render output.
- `Schema`, at `marker/schema`. The classes for all the block types.
- `Converters`, at `marker/converters`. They run the whole end to end pipeline.
To customize processing behavior, override the `processors`. To add new output formats, write a new `renderer`. For additional input formats, write a new `provider.`
Processors and renderers can be directly passed into the base `PDFConverter`, so you can specify your own custom processing easily.
## API server
There is a very simple API server you can run like this:
```shell
pip install -U uvicorn fastapi python-multipart
marker_server --port 8001
```
This will start a fastapi server that you can access at `localhost:8001`. You can go to `localhost:8001/docs` to see the endpoint options.
You can send requests like this:
```
import requests
import json
post_data = {
'filepath': 'FILEPATH',
# Add other params here
}
requests.post("http://localhost:8001/marker", data=json.dumps(post_data)).json()
```
Note that this is not a very robust API, and is only intended for small-scale use. If you want to use this server, but want a more robust conversion option, you can use the hosted [Datalab API](https://www.datalab.to/plans).
# Troubleshooting
There are some settings that you may find useful if things aren't working the way you expect:
- If you have issues with accuracy, try setting `--use_llm` to use an LLM to improve quality. You must set `GOOGLE_API_KEY` to a Gemini API key for this to work.
- Make sure to set `force_ocr` if you see garbled text - this will re-OCR the document.
- `TORCH_DEVICE` - set this to force marker to use a given torch device for inference.
- If you're getting out of memory errors, decrease worker count. You can also try splitting up long PDFs into multiple files.
## Debugging
Pass the `debug` option to activate debug mode. This will save images of each page with detected layout and text, as well as output a json file with additional bounding box information.
# Benchmarks
## Overall PDF Conversion
We created a [benchmark set](https://huggingface.co/datasets/datalab-to/marker_benchmark) by extracting single PDF pages from common crawl. We scored based on a heuristic that aligns text with ground truth text segments, and an LLM as a judge scoring method.
| Method | Avg Time | Heuristic Score | LLM Score |
|------------|----------|-----------------|-----------|
| marker | 2.83837 | 95.6709 | 4.23916 |
| llamaparse | 23.348 | 84.2442 | 3.97619 |
| mathpix | 6.36223 | 86.4281 | 4.15626 |
| docling | 3.69949 | 86.7073 | 3.70429 |
Benchmarks were run on an H100 for markjer and docling - llamaparse and mathpix used their cloud services. We can also look at it by document type:
<img src="data/images/per_doc.png" width="1000px"/>
| Document Type | Marker heuristic | Marker LLM | Llamaparse Heuristic | Llamaparse LLM | Mathpix Heuristic | Mathpix LLM | Docling Heuristic | Docling LLM |
|----------------------|------------------|------------|----------------------|----------------|-------------------|-------------|-------------------|-------------|
| Scientific paper | 96.6737 | 4.34899 | 87.1651 | 3.96421 | 91.2267 | 4.46861 | 92.135 | 3.72422 |
| Book page | 97.1846 | 4.16168 | 90.9532 | 4.07186 | 93.8886 | 4.35329 | 90.0556 | 3.64671 |
| Other | 95.1632 | 4.25076 | 81.1385 | 4.01835 | 79.6231 | 4.00306 | 83.8223 | 3.76147 |
| Form | 88.0147 | 3.84663 | 66.3081 | 3.68712 | 64.7512 | 3.33129 | 68.3857 | 3.40491 |
| Presentation | 95.1562 | 4.13669 | 81.2261 | 4 | 83.6737 | 3.95683 | 84.8405 | 3.86331 |
| Financial document | 95.3697 | 4.39106 | 82.5812 | 4.16111 | 81.3115 | 4.05556 | 86.3882 | 3.8 |
| Letter | 98.4021 | 4.5 | 93.4477 | 4.28125 | 96.0383 | 4.45312 | 92.0952 | 4.09375 |
| Engineering document | 93.9244 | 4.04412 | 77.4854 | 3.72059 | 80.3319 | 3.88235 | 79.6807 | 3.42647 |
| Legal document | 96.689 | 4.27759 | 86.9769 | 3.87584 | 91.601 | 4.20805 | 87.8383 | 3.65552 |
| Newspaper page | 98.8733 | 4.25806 | 84.7492 | 3.90323 | 96.9963 | 4.45161 | 92.6496 | 3.51613 |
| Magazine page | 98.2145 | 4.38776 | 87.2902 | 3.97959 | 93.5934 | 4.16327 | 93.0892 | 4.02041 |
## Throughput
We benchmarked throughput using a [single long PDF](https://www.greenteapress.com/thinkpython/thinkpython.pdf).
| Method | Time per page | Time per document | VRAM used |
|---------|---------------|-------------------|---------- |
| marker | 0.18 | 43.42 | 3.17GB |
The projected throughput is 122 pages per second on an H100 - we can run 22 individual processes given the VRAM used.
## Table Conversion
Marker can extract tables from PDFs using `marker.converters.table.TableConverter`. The table extraction performance is measured by comparing the extracted HTML representation of tables against the original HTML representations using the test split of [FinTabNet](https://developer.ibm.com/exchanges/data/all/fintabnet/). The HTML representations are compared using a tree edit distance based metric to judge both structure and content. Marker detects and identifies the structure of all tables in a PDF page and achieves these scores:
| Method | Avg score | Total tables |
|------------------|-----------|--------------|
| marker | 0.816 | 99 |
| marker w/use_llm | 0.907 | 99 |
| gemini | 0.829 | 99 |
The `--use_llm` flag can significantly improve table recognition performance, as you can see.
We filter out tables that we cannot align with the ground truth, since fintabnet and our layout model have slightly different detection methods (this results in some tables being split/merged).
## Running your own benchmarks
You can benchmark the performance of marker on your machine. Install marker manually with:
```shell
git clone https://github.com/VikParuchuri/marker.git
poetry install
```
### Overall PDF Conversion
Download the benchmark data [here](https://drive.google.com/file/d/1ZSeWDo2g1y0BRLT7KnbmytV2bjWARWba/view?usp=sharing) and unzip. Then run the overall benchmark like this:
```shell
python benchmarks/overall.py --methods marker --scores heuristic,llm
```
Options:
- `--use_llm` use an llm to improve the marker results.
- `--max_rows` how many rows to process for the benchmark.
- `--methods` can be `llamaparse`, `mathpix`, `docling`, `marker`. Comma separated.
- `--scores` which scoring functions to use, can be `llm`, `heuristic`. Comma separated.
### Table Conversion
The processed FinTabNet dataset is hosted [here](https://huggingface.co/datasets/datalab-to/fintabnet-test) and is automatically downloaded. Run the benchmark with:
```shell
python benchmarks/table/table.py --max_rows 100
```
Options:
- `--use_llm` uses an llm with marker to improve accuracy.
- `--use_gemini` also benchmarks gemini 2.0 flash.
# How it works
Marker is a pipeline of deep learning models:
- Extract text, OCR if necessary (heuristics, [surya](https://github.com/VikParuchuri/surya))
- Detect page layout and find reading order ([surya](https://github.com/VikParuchuri/surya))
- Clean and format each block (heuristics, [texify](https://github.com/VikParuchuri/texify), [surya](https://github.com/VikParuchuri/surya))
- Optionally use an LLM to improve quality
- Combine blocks and postprocess complete text
It only uses models where necessary, which improves speed and accuracy.
# Limitations
PDF is a tricky format, so marker will not always work perfectly. Here are some known limitations that are on the roadmap to address:
- Very complex layouts, with nested tables and forms, may not work
- Forms may not be rendered well
Note: Passing the `--use_llm` flag will mostly solve these issues.
# Thanks
This work would not have been possible without amazing open source models and datasets, including (but not limited to):
- Surya
- Texify
- Pypdfium2/pdfium
- DocLayNet from IBM
Thank you to the authors of these models and datasets for making them available to the community!

@ -0,0 +1,53 @@
import json
from typing import List
import datasets
from tqdm import tqdm
from benchmarks.overall.registry import METHOD_REGISTRY
from benchmarks.overall.schema import FullResult
def build_dataset(bench_dataset: datasets.Dataset, result: FullResult, score_types: List[str], max_rows: int | None = None) -> datasets.Dataset:
rows = []
for idx, sample in tqdm(enumerate(bench_dataset), desc="Building dataset"):
if idx not in result["markdown"]:
continue
if max_rows is not None and idx >= max_rows:
break
row = {
"uuid": sample["uuid"],
"classification": sample["classification"],
"language": sample["language"],
"img": sample["img"],
}
for method in result["markdown"][idx]:
if method == "gt":
continue
method_cls = METHOD_REGISTRY[method]()
md = result["markdown"][idx][method]
try:
method_img = method_cls.render(result["markdown"][idx][method])
except Exception as e:
# This can happen when the markdown is None
method_img = PIL.Image.new("RGB", (200, 200))
row[f"{method}_md"] = md
row[f"{method}_img"] = method_img
for score_type in score_types:
try:
row[f"{method}_{score_type}"] = result["scores"][idx][method][score_type]["score"]
except KeyError:
row[f"{method}_{score_type}"] = -1.0 # Missing score
try:
row[f"{method}_{score_type}_detail"] = json.dumps(result["scores"][idx][method][score_type]["specific_scores"])
except KeyError:
row[f"{method}_{score_type}_detail"] = "" # Missing detail
rows.append(row)
ds = datasets.Dataset.from_list(rows)
return ds

@ -0,0 +1,68 @@
from pathlib import Path
from typing import Dict, List
import tabulate
from benchmarks.overall.schema import FullResult
def write_table(title: str, rows: list, headers: list, out_path: Path, filename: str):
table = tabulate.tabulate(rows, headers=headers, tablefmt="github")
with open(out_path / filename, "w", encoding="utf-8") as f:
f.write(f"# {title}\n")
f.write(table)
print(title)
print(table)
def print_scores(result: FullResult, out_path: Path, methods: List[str], score_types: List[str], default_score_type="heuristic", default_method="marker"):
document_types = list(result["averages_by_type"][default_method][default_score_type].keys())
headers = ["Document Type"]
for method in methods:
for score_type in score_types:
headers.append(f"{method} {score_type}")
document_rows = [[k] for k in document_types]
for i, doc_type in enumerate(document_types):
for method in methods:
for score_type in score_types:
avg_score = sum(result["averages_by_type"][method][score_type][doc_type]) / max(1, len(result["averages_by_type"][method][score_type][doc_type]))
document_rows[i].append(avg_score)
write_table("Document Types", document_rows, headers, out_path, "document_types.md")
headers = ["Block Type"]
block_types = list(result["averages_by_block_type"][default_method][default_score_type].keys()) # all possible blocks
block_score_types = list(result["averages_by_block_type"][default_method].keys())
for method in methods:
for score_type in block_score_types:
headers.append(f"{method} {score_type}")
block_rows = [[k] for k in block_types]
for i, block_type in enumerate(block_types):
for method in methods:
for score_type in block_score_types:
avg_score = sum(result["averages_by_block_type"][method][score_type][block_type]) / max(1, len(result["averages_by_block_type"][method][score_type][block_type]))
block_rows[i].append(avg_score)
write_table("Block types", block_rows, headers, out_path, "block_types.md")
headers = ["Method", "Avg Time"] + score_types
inference_rows = [[k] for k in methods]
all_raw_scores = [result["scores"][i] for i in result["scores"]]
for i, method in enumerate(methods):
avg_time = sum(result["average_times"][method]) / max(1, len(result["average_times"][method]))
inference_rows[i].append(avg_time)
for score_type in score_types:
scores_lst = []
for ar in all_raw_scores:
try:
# Sometimes a few llm scores are missing
scores_lst.append(ar[method][score_type]["score"])
except KeyError:
continue
avg_score = sum(scores_lst) / max(1, len(scores_lst))
inference_rows[i].append(avg_score)
write_table("Overall Results", inference_rows, headers, out_path, "overall.md")
print("Scores computed by aligning ground truth markdown blocks with predicted markdown for each method. The scores are 0-100 based on edit distance.")

@ -0,0 +1,63 @@
import json
from json import JSONDecodeError
from pathlib import Path
import datasets
from tqdm import tqdm
class Downloader:
cache_path: Path = Path("cache")
service: str
def __init__(self, api_key, app_id, max_rows: int = 2200):
self.cache_path.mkdir(exist_ok=True)
self.max_rows = max_rows
self.api_key = api_key
self.app_id = app_id
self.ds = datasets.load_dataset("datalab-to/marker_benchmark", split="train")
def get_html(self, pdf_bytes):
raise NotImplementedError
def upload_ds(self):
rows = []
for file in self.cache_path.glob("*.json"):
with open(file, "r") as f:
data = json.load(f)
rows.append(data)
out_ds = datasets.Dataset.from_list(rows, features=datasets.Features({
"md": datasets.Value("string"),
"uuid": datasets.Value("string"),
"time": datasets.Value("float"),
}))
out_ds.push_to_hub(f"datalab-to/marker_benchmark_{self.service}", private=True)
def generate_data(self):
max_rows = self.max_rows
for idx, sample in tqdm(enumerate(self.ds), desc=f"Saving {self.service} results"):
cache_file = self.cache_path / f"{idx}.json"
if cache_file.exists():
continue
pdf_bytes = sample["pdf"] # This is a single page PDF
try:
out_data = self.get_html(pdf_bytes)
except JSONDecodeError as e:
print(f"Error with sample {idx}: {e}")
continue
except Exception as e:
print(f"Error with sample {idx}: {e}")
continue
out_data["uuid"] = sample["uuid"]
with cache_file.open("w") as f:
json.dump(out_data, f)
if idx >= max_rows:
break
def __call__(self):
self.generate_data()
self.upload_ds()

@ -0,0 +1,63 @@
import io
import time
import requests
from benchmarks.overall.download.base import Downloader
class LlamaParseDownloader(Downloader):
service = "llamaparse"
def get_html(self, pdf_bytes):
rand_name = str(time.time()) + ".pdf"
start = time.time()
buff = io.BytesIO(pdf_bytes)
md = upload_and_parse_file(self.api_key, rand_name, buff)
end = time.time()
if isinstance(md, bytes):
md = md.decode("utf-8")
return {
"md": md,
"time": end - start,
}
def upload_and_parse_file(api_key: str, fname: str, buff, max_retries: int = 180, delay: int = 1):
headers = {
"Authorization": f"Bearer {api_key}",
"Accept": "application/json"
}
# Upload file
files = {
'file': (fname, buff, 'application/pdf')
}
response = requests.post(
'https://api.cloud.llamaindex.ai/api/v1/parsing/upload',
headers=headers,
files=files
)
response.raise_for_status()
job_id = response.json()['id']
# Poll for completion
for _ in range(max_retries):
status_response = requests.get(
f'https://api.cloud.llamaindex.ai/api/v1/parsing/job/{job_id}',
headers=headers
)
status_response.raise_for_status()
if status_response.json()['status'] == 'SUCCESS':
# Get results
result_response = requests.get(
f'https://api.cloud.llamaindex.ai/api/v1/parsing/job/{job_id}/result/markdown',
headers=headers
)
result_response.raise_for_status()
return result_response.json()['markdown']
time.sleep(delay)
raise TimeoutError("Job did not complete within the maximum retry attempts")

@ -0,0 +1,25 @@
import click
from benchmarks.overall.download.llamaparse import LlamaParseDownloader
from benchmarks.overall.download.mathpix import MathpixDownloader
from benchmarks.overall.download.mistral import MistralDownloader
@click.command("Download data from inference services")
@click.argument("service", type=click.Choice(["mathpix", "llamaparse", "mistral"]))
@click.option("--max_rows", type=int, default=2200)
@click.option("--api_key", type=str, default=None)
@click.option("--app_id", type=str, default=None)
def main(service: str, max_rows: int, api_key: str, app_id: str):
registry = {
"mathpix": MathpixDownloader,
"llamaparse": LlamaParseDownloader,
"mistral": MistralDownloader,
}
downloader = registry[service](api_key, app_id, max_rows=max_rows)
# Generate data and upload to hub
downloader()
if __name__ == "__main__":
main()

@ -0,0 +1,80 @@
import json
import time
import requests
from benchmarks.overall.download.base import Downloader
class MathpixDownloader(Downloader):
service = "mathpix"
def get_html(self, pdf_bytes):
headers = {
"app_id": self.app_id,
"app_key": self.api_key,
}
start = time.time()
pdf_id = mathpix_request(pdf_bytes, headers)
status = mathpix_status(pdf_id, headers)
if status in ["processing", "error"]:
md = ""
else:
md = mathpix_results(pdf_id, headers)
end = time.time()
if isinstance(md, bytes):
md = md.decode("utf-8")
return {
"md": md,
"time": end - start
}
def mathpix_request(buffer, headers):
response = requests.post("https://api.mathpix.com/v3/pdf",
headers=headers,
data={
"options_json": json.dumps(
{
"conversion_formats": {
"md": True,
"html": True
}
}
)
},
files={
"file": buffer
}
)
data = response.json()
pdf_id = data["pdf_id"]
return pdf_id
def mathpix_status(pdf_id, headers):
max_iters = 120
i = 0
status = "processing"
status2 = "processing"
while i < max_iters:
time.sleep(1)
response = requests.get(f"https://api.mathpix.com/v3/converter/{pdf_id}",
headers=headers
)
status_resp = response.json()
if "conversion_status" not in status_resp:
continue
status = status_resp["conversion_status"]["md"]["status"]
status2 = status_resp["conversion_status"]["html"]["status"]
if status == "completed" and status2 == "completed":
break
elif status == "error" or status2 == "error":
break
out_status = "completed" if status == "completed" and status2 == "completed" else "error"
return out_status
def mathpix_results(pdf_id, headers, ext="md"):
response = requests.get(f"https://api.mathpix.com/v3/converter/{pdf_id}.{ext}",
headers=headers
)
return response.content

@ -0,0 +1,73 @@
import io
import time
import requests
from benchmarks.overall.download.base import Downloader
class MistralDownloader(Downloader):
service = "mistral"
def get_html(self, pdf_bytes):
rand_name = str(time.time()) + ".pdf"
start = time.time()
buff = io.BytesIO(pdf_bytes)
md = upload_and_process_file(self.api_key, rand_name, buff)
end = time.time()
if isinstance(md, bytes):
md = md.decode("utf-8")
return {
"md": md,
"time": end - start,
}
def upload_and_process_file(api_key: str, fname: str, buff):
headers = {
"Authorization": f"Bearer {api_key}"
}
upload_headers = headers.copy()
files = {
'file': (fname, buff, 'application/pdf'),
'purpose': (None, 'ocr')
}
upload_response = requests.post(
'https://api.mistral.ai/v1/files',
headers=upload_headers,
files=files
)
upload_response.raise_for_status()
file_id = upload_response.json()['id']
url_headers = headers.copy()
url_headers["Accept"] = "application/json"
url_response = requests.get(
f'https://api.mistral.ai/v1/files/{file_id}/url?expiry=24',
headers=url_headers
)
url_response.raise_for_status()
signed_url = url_response.json()['url']
ocr_headers = headers.copy()
ocr_headers["Content-Type"] = "application/json"
ocr_data = {
"model": "mistral-ocr-latest",
"document": {
"type": "document_url",
"document_url": signed_url
},
"include_image_base64": True
}
ocr_response = requests.post(
'https://api.mistral.ai/v1/ocr',
headers=ocr_headers,
json=ocr_data
)
ocr_response.raise_for_status()
result = ocr_response.json()
return result["pages"][0]["markdown"]

@ -0,0 +1,221 @@
import json
import random
import time
import os
from dataclasses import dataclass
from typing import List, Dict, Tuple, Literal
from PIL import Image
from collections import defaultdict
import tabulate
import click
import datasets
from google import genai
from google.genai.errors import APIError
from pydantic import BaseModel
from tqdm import tqdm
from marker.settings import settings
rating_prompt = """
You're a document analysis expert who is comparing two different markdown samples to an image to see which one represents the content of the image better. The markdown will be called version A and version B.
Here are some notes on the image and markdown:
- Some parts of the page may have been recognized as images and linked from the markdown, like `![](_page_0_Picture_0.jpeg)`.
- Tables will be formatted as Github flavored markdown.
- Block equations will be in LaTeX.
- The image and markdown may be in any language.
- The markdown is based on the text extracted from the document, and sometimes the document may have had bad OCR applied to it, resulting in gibberish text.
The markdown should fully capture the meaning and formatting of the text in the image. You'll evaluate the markdown based on the image provided.
**Instructions**
Follow this process to evaluate the markdown:
1. Carefully examine the image.
2. Carefully examine the first markdown input provided.
3. Describe how well version a represents the image.
4. Carefully examine the second markdown input provided.
5. Describe how well version B represents the image.
6. Compare version A and version B.
7. Decide which markdown representation is better, based on the criteria below. Output version_a if version a is better, and version_b if version b is better.
Use these criteria when judging the markdown:
- Overall - the overall quality of the markdown as compared to the image.
- Text quality - the quality of the text extraction from the image.
- Formatting quality - the quality of the formatting applied to the markdown, as compared to the image.
- Tables - how effectively the tables have been extracted and formatted.
- Forms - how effectively the forms have extracted and formatted.
- Equations - how effectively block equations have been converted to LaTeX.
- Lists - if the lists have been properly extracted and formatted.
- Images - if images are identified and placed correctly.
Notes on scoring:
- Perfect markdown will include all of the important text from the image, and the formatting will be correct (minor mistakes okay). It's okay to omit some text that isn't important to the meaning, like page numbers and chapter headings. If the entire page is an image, it's okay if the markdown is just a link to the image, unless the image would be better represented as text.
- Bad markdown will have major missing text segments from the markdown or completely unreadable formatting. It may also have key values that are different from the values in the image.
Output json, like in the example below.
**Example**
Version A
```markdown
# *Section 1*
This is some *markdown* extracted from a document. Here is a block equation:
$$\frac{ab \cdot x^5 + x^2 + 2 \cdot x + 123}{t}$$
```
Version B
```markdown
# Section 1
This is some markdown extracted from a document. Here is a block equation:
$$\frac{ab \cdot x^5 + x^2 + 2 \cdot x + 124}{t}$$
```
Output
```json
{
"image_description": "In the image, there is a section header 'Section 1', followed by some text and a block equation.",
"version_a_description": "In the markdown, there is a section header 'Section 1', followed by some text and a block equation.",
"version_b_description": "In the markdown, there is a section header 'Section 1', followed by some text and a block equation. The formatting in version b is slightly different from the image. The value 124 is also different from the image.",
"comparison": "Version A is better than version B. The text and formatting in version A matches the image better than version B. Version B also has an incorrect value.",
"winner": "version_a",
}
```
**Input**
Version A
```markdown
{{version_a}}
```
Version B
```markdown
{{version_b}}
```
**Output**
"""
class ComparerSchema(BaseModel):
image_description: str
version_a_description: str
version_b_description: str
comparison: str
winner: Literal["version_a", "version_b"]
class Comparer:
def __init__(self):
pass
def __call__(
self,
img: Image.Image,
version_a: str,
version_b: str
) -> str | None:
if version_a is None and version_b is not None:
return "version_b"
elif version_b is None and version_a is not None:
return "version_a"
hydrated_prompt = rating_prompt.replace("{{version_a}}", version_a).replace("{{version_b}}", version_b)
try:
rating = self.llm_rater(img, hydrated_prompt)
except Exception as e:
print(f"Error: {e}")
return
return rating
def llm_rater(self, img: Image.Image, prompt: str):
response = self.llm_response_wrapper(
[img, prompt],
ComparerSchema
)
assert "winner" in response, f"Response missing 'winner' key: {response}"
return response["winner"]
def llm_response_wrapper(
self,
prompt,
response_schema,
):
client = genai.Client(
http_options={"timeout": 60000},
vertexai=True,
project=os.getenv("VERTEX_PROJECT_ID"),
location=os.getenv("VERTEX_LOCATION"),
)
try:
responses = client.models.generate_content(
model="gemini-2.0-flash-001",
contents=prompt,
config={
"temperature": 0,
"response_schema": response_schema,
"response_mime_type": "application/json",
},
)
output = responses.candidates[0].content.parts[0].text
return json.loads(output)
except APIError as e:
print(f"Hit Gemini rate limit")
return
except Exception as e:
print(f"Error: {e}")
return
def display_win_rates_table(win_rates: dict):
table = []
headers = ["Method A", "Method B", "Wins", "Losses", "Win %"]
for method_a, method_b_dict in win_rates.items():
row = [method_a]
for method_b, results in method_b_dict.items():
row = [method_a, method_b, results["win"], results["loss"], (results["win"] / (results["win"] + results["loss"])) * 100]
table.append(row)
print(tabulate.tabulate(table, headers=headers, tablefmt="pretty"))
@click.command("Calculate win rates for document conversion methods")
@click.argument("dataset", type=str)
@click.option("--methods", type=str, help="List of methods to compare: comma separated like marker,mathpix")
@click.option("--row_samples", type=int, default=2, help="Number of samples per row")
@click.option("--max_rows", type=int, default=None, help="Maximum number of rows to process")
def main(
dataset: str,
methods: str,
row_samples: int,
max_rows: int
):
ds = datasets.load_dataset(dataset, split="train")
method_lst = methods.split(",")
win_rates = {m: defaultdict(lambda: defaultdict(int)) for m in method_lst}
comparer = Comparer()
max_rows = max_rows or len(ds)
for i in tqdm(range(max_rows), desc="Calculating win rates..."):
row = ds[i]
# Avoid any bias in ordering
random.shuffle(method_lst)
for j, method_a in enumerate(method_lst[:-1]):
for z, method_b in enumerate(method_lst[j:]):
if method_a == method_b:
continue
method_a_md = row[f"{method_a}_md"]
method_b_md = row[f"{method_b}_md"]
winner = comparer(row["img"], method_a_md, method_b_md)
if not winner:
continue
if winner == "version_a":
win_rates[method_a][method_b]["win"] += 1
win_rates[method_b][method_a]["loss"] += 1
else:
win_rates[method_b][method_a]["win"] += 1
win_rates[method_a][method_b]["loss"] += 1
if i % 10 == 0:
display_win_rates_table(win_rates)
display_win_rates_table(win_rates)
if __name__ == "__main__":
main()

@ -0,0 +1,100 @@
import io
import random
import re
from typing import Tuple
import markdown2
from PIL import Image
from playwright.sync_api import sync_playwright
from benchmarks.overall.methods.schema import BenchmarkResult
from marker.renderers.markdown import MarkdownRenderer
class BaseMethod:
def __init__(self, **kwargs):
for kwarg in kwargs:
if hasattr(self, kwarg):
setattr(self, kwarg, kwargs[kwarg])
@staticmethod
def convert_to_md(html: str):
md = MarkdownRenderer()
markdown = md.md_cls.convert(html)
return markdown
def __call__(self, sample) -> BenchmarkResult:
raise NotImplementedError()
def render(self, markdown: str):
return self.html_to_image(self.convert_to_html(markdown))
@staticmethod
def convert_to_html(md: str):
block_placeholders = []
inline_placeholders = []
# Add placeholders for the math
def block_sub(match):
content = match.group(1)
placeholder = f"1BLOCKMATH{len(block_placeholders)}1"
block_placeholders.append((placeholder, f"$${content}$$"))
return placeholder
def inline_sub(match):
content = match.group(1)
placeholder = f"1INLINEMATH{len(inline_placeholders)}1"
inline_placeholders.append((placeholder, f"${content}$"))
return placeholder
md = re.sub(r'\${2}(.*?)\${2}', block_sub, md, flags=re.DOTALL)
md = re.sub(r'\$(.*?)\$', inline_sub, md)
html = markdown2.markdown(md, extras=['tables'])
# Replace placeholders
for placeholder, math_str in block_placeholders:
html = html.replace(placeholder, math_str)
for placeholder, math_str in inline_placeholders:
html = html.replace(placeholder, math_str)
return html
def html_to_image(self, html: str) -> Image.Image:
with sync_playwright() as p:
browser = p.chromium.launch()
page = browser.new_page()
html_str = f"""
<!DOCTYPE html>
<html>
<head>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.16.21/dist/katex.min.css" integrity="sha384-zh0CIslj+VczCZtlzBcjt5ppRcsAmDnRem7ESsYwWwg3m/OaJ2l4x7YBZl9Kxxib" crossorigin="anonymous">
<!-- The loading of KaTeX is deferred to speed up page rendering -->
<script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.21/dist/katex.min.js" integrity="sha384-Rma6DA2IPUwhNxmrB/7S3Tno0YY7sFu9WSYMCuulLhIqYSGZ2gKCJWIqhBWqMQfh" crossorigin="anonymous"></script>
<!-- To automatically render math in text elements, include the auto-render extension: -->
<script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.21/dist/contrib/auto-render.min.js" integrity="sha384-hCXGrW6PitJEwbkoStFjeJxv+fSOOQKOPbJxSfM6G5sWZjAyWhXiTIIAmQqnlLlh" crossorigin="anonymous"></script>
</head>
<body>
{html}
<script>
document.addEventListener("DOMContentLoaded", function() {{
renderMathInElement(document.body, {{
delimiters: [
{{left: '$$', right: '$$', display: true}},
{{left: '$', right: '$', display: false}}
],
throwOnError : false
}});
}});
</script>
</body>
</html>
""".strip()
page.set_viewport_size({"width": 1200, "height": 800})
page.set_content(html_str)
page.wait_for_load_state("domcontentloaded")
page.wait_for_timeout(500) # Wait for KaTeX to render
screenshot_bytes = page.screenshot(full_page=True)
browser.close()
return Image.open(io.BytesIO(screenshot_bytes))

@ -0,0 +1,26 @@
import tempfile
import time
from benchmarks.overall.methods import BaseMethod, BenchmarkResult
class DoclingMethod(BaseMethod):
model_dict: dict = None
use_llm: bool = False
def __call__(self, sample) -> BenchmarkResult:
from docling.document_converter import DocumentConverter
pdf_bytes = sample["pdf"] # This is a single page PDF
converter = DocumentConverter()
with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as f:
f.write(pdf_bytes)
start = time.time()
result = converter.convert(f.name)
total = time.time() - start
return {
"markdown": result.document.export_to_markdown(),
"time": total
}

@ -0,0 +1,29 @@
from typing import List
import json
from PIL import Image
from benchmarks.overall.methods import BaseMethod, BenchmarkResult
class GTMethod(BaseMethod):
def __call__(self, sample) -> BenchmarkResult:
gt_blocks = json.loads(sample["gt_blocks"])
gt_html = [block["html"] for block in gt_blocks if len(block["html"]) > 0]
gt_markdown = [self.convert_to_md(block) for block in gt_html]
return {
"markdown": gt_markdown,
"time": 0
}
def render(self, html: List[str]) -> Image.Image:
joined = "\n\n".join(html)
html = f"""
<html>
<head></head>
<body>
{joined}
</body>
</html>
""".strip()
return self.html_to_image(html)

@ -0,0 +1,22 @@
import datasets
from benchmarks.overall.methods import BaseMethod, BenchmarkResult
class LlamaParseMethod(BaseMethod):
llamaparse_ds: datasets.Dataset = None
def __call__(self, sample) -> BenchmarkResult:
uuid = sample["uuid"]
data = None
for row in self.llamaparse_ds:
if str(row["uuid"]) == str(uuid):
data = row
break
if not data:
raise ValueError(f"Could not find data for uuid {uuid}")
return {
"markdown": data["md"],
"time": data["time"]
}

@ -0,0 +1,41 @@
import os
import tempfile
import time
from benchmarks.overall.methods import BaseMethod, BenchmarkResult
from marker.config.parser import ConfigParser
from marker.converters.pdf import PdfConverter
class MarkerMethod(BaseMethod):
model_dict: dict = None
use_llm: bool = False
def __call__(self, sample) -> BenchmarkResult:
pdf_bytes = sample["pdf"] # This is a single page PDF
parser = ConfigParser({
"page_range": "0",
"disable_tqdm": True,
"use_llm": self.use_llm,
"redo_inline_math": self.use_llm,
"llm_service": "marker.services.vertex.GoogleVertexService",
"vertex_project_id": os.getenv("VERTEX_PROJECT_ID"),
})
block_converter = PdfConverter(
artifact_dict=self.model_dict,
config=parser.generate_config_dict(),
llm_service=parser.get_llm_service()
)
with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as f:
f.write(pdf_bytes)
start = time.time()
rendered = block_converter(f.name)
total = time.time() - start
return {
"markdown": rendered.markdown,
"time": total
}

@ -0,0 +1,22 @@
import datasets
from benchmarks.overall.methods import BaseMethod, BenchmarkResult
class MathpixMethod(BaseMethod):
mathpix_ds: datasets.Dataset = None
def __call__(self, sample) -> BenchmarkResult:
uuid = sample["uuid"]
data = None
for row in self.mathpix_ds:
if str(row["uuid"]) == str(uuid):
data = row
break
if not data:
raise ValueError(f"Could not find data for uuid {uuid}")
return {
"markdown": data["md"],
"time": data["time"]
}

@ -0,0 +1,22 @@
import datasets
from benchmarks.overall.methods import BaseMethod, BenchmarkResult
class MistralMethod(BaseMethod):
mistral_ds: datasets.Dataset = None
def __call__(self, sample) -> BenchmarkResult:
uuid = sample["uuid"]
data = None
for row in self.mistral_ds:
if str(row["uuid"]) == str(uuid):
data = row
break
if not data:
raise ValueError(f"Could not find data for uuid {uuid}")
return {
"markdown": data["md"],
"time": data["time"]
}

@ -0,0 +1,91 @@
import base64
import json
import tempfile
import time
from io import BytesIO
import torch
from PIL import Image
from benchmarks.overall.methods import BaseMethod, BenchmarkResult
def convert_single_page(filename: str, model, processor, device):
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
image_base64 = render_pdf_to_base64png(filename, 1, target_longest_image_dim=1024)
# Build the prompt, using document metadata
anchor_text = get_anchor_text(filename, 1, pdf_engine="pdfreport", target_length=4000)
prompt = build_finetuning_prompt(anchor_text)
# Build the full prompt
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
],
}
]
# Apply the chat template and processor
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
main_image = Image.open(BytesIO(base64.b64decode(image_base64)))
inputs = processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="pt",
)
inputs = {key: value.to(device) for (key, value) in inputs.items()}
# Generate the output
output = model.generate(
**inputs,
temperature=0.8,
max_new_tokens=8192,
num_return_sequences=1,
do_sample=True,
)
# Decode the output
prompt_length = inputs["input_ids"].shape[1]
new_tokens = output[:, prompt_length:]
text_output = processor.tokenizer.batch_decode(
new_tokens, skip_special_tokens=True
)[0]
try:
text_output = json.loads(text_output)
text = text_output["natural_text"]
except Exception:
try:
text = text_output.split("natural_text")[1].strip()
except Exception:
text = ""
return text
class OlmOCRMethod(BaseMethod):
olmocr_model: dict = None
use_llm: bool = False
def __call__(self, sample) -> BenchmarkResult:
pdf_bytes = sample["pdf"] # This is a single page PDF
with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as f:
f.write(pdf_bytes)
start = time.time()
result = convert_single_page(f.name, self.olmocr_model["model"], self.olmocr_model["processor"], self.olmocr_model["model"].device)
total = time.time() - start
return {
"markdown": result,
"time": total
}

@ -0,0 +1,6 @@
from typing import TypedDict, List
class BenchmarkResult(TypedDict):
markdown: str | List[str]
time: float | None

@ -0,0 +1,178 @@
import json
import os
import traceback
from collections import defaultdict
from pathlib import Path
from typing import List
import click
import datasets
import torch
from tqdm import tqdm
from benchmarks.overall.display.dataset import build_dataset
from benchmarks.overall.registry import SCORE_REGISTRY, METHOD_REGISTRY
from benchmarks.overall.schema import FullResult
from marker.logger import configure_logging
from marker.models import create_model_dict
from marker.settings import settings
from benchmarks.overall.display.table import print_scores
configure_logging()
def get_method_scores(benchmark_dataset: datasets.Dataset, methods: List[str], score_types: List[str], artifacts: dict, max_rows=None) -> FullResult:
bench_scores = {}
averages_by_type = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
averages_by_block_type = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
average_times = defaultdict(list)
markdown_by_method = defaultdict(dict)
total_rows = len(benchmark_dataset)
if max_rows:
total_rows = min(max_rows, total_rows)
for idx, sample in tqdm(enumerate(benchmark_dataset), desc="Running benchmark", total=total_rows):
if max_rows is not None and idx >= max_rows:
break
doc_type = sample["classification"]
gt_cls = METHOD_REGISTRY["gt"]
gt_blocks = json.loads(sample["gt_blocks"])
gt_md = gt_cls(**artifacts)(sample)["markdown"]
markdown_by_method[idx]["gt"] = gt_md
out_data = defaultdict(dict)
try:
for method in methods:
method_cls = METHOD_REGISTRY[method](**artifacts)
method_info = method_cls(sample)
method_md = method_info["markdown"]
if method_md is None:
method_md = "" # Avoid None values
average_times[method].append(method_info["time"])
markdown_by_method[idx][method] = method_md
for score_type in score_types:
score_cls = SCORE_REGISTRY[score_type]()
try:
scores = score_cls(sample, gt_md, method_md)
except Exception as e:
# Some scorers can fail, like the LLM one
print(f"Failed to score {method} with {score_type}: {e}")
continue
out_data[method][score_type] = scores
averages_by_type[method][score_type][doc_type].append(scores["score"])
if "by_block" in scores["specific_scores"]: # Not all scorers support this
for score, gt_block in zip(scores["specific_scores"]["by_block"], gt_blocks):
averages_by_block_type[method][score_type][gt_block["block_type"]].append(score)
except Exception as e:
print(f"Failed to process {idx}: {e}")
traceback.print_exc()
if idx in markdown_by_method:
del markdown_by_method[idx]
continue
bench_scores[idx] = out_data
return {
"scores": bench_scores,
"markdown": markdown_by_method,
"averages_by_type": averages_by_type,
"averages_by_block_type": averages_by_block_type,
"average_times": average_times,
}
@click.command(help="Benchmark PDF to MD conversion.")
@click.option("--dataset", type=str, help="Path to the benchmark dataset", default="datalab-to/marker_benchmark")
@click.option("--out_dataset", type=str, help="Path to the output dataset", default=None)
@click.option("--methods", type=str, help="Comma separated list of other methods to compare against. Possible values: marker,mathpix,llamaparse,docling,mistral", default="marker")
@click.option("--scores", type=str, help="Comma separated list of scoring functions to use. Possible values: heuristic,llm", default="heuristic")
@click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "overall"), help="Output path for results.")
@click.option("--max_rows", type=int, default=None, help="Maximum number of rows to process.")
@click.option("--use_llm", is_flag=True, help="Use the LLM model for better marker quality.")
@click.option("--languages", type=str, help="Comma separated list of languages to use for LLM", default=None)
def main(
dataset: str,
out_dataset: str,
methods: str,
scores: str,
result_path: str,
max_rows: int,
use_llm: bool,
languages: str
):
out_path = Path(result_path)
out_path.mkdir(parents=True, exist_ok=True)
methods = methods.split(",")
for method in methods:
if method not in METHOD_REGISTRY:
raise ValueError(f"Method {method} not allowed. Allowed methods are {METHOD_REGISTRY.keys()}")
# Ensure marker is always first
all_methods = list(set(methods))
methods = ["marker"] if "marker" in all_methods else []
methods += [m for m in all_methods if m != "marker"]
score_types = scores.split(",")
for score_type in score_types:
if score_type not in SCORE_REGISTRY:
raise ValueError(f"Score type {score_type} not allowed. Allowed types are {SCORE_REGISTRY.keys()}")
if languages:
languages = languages.split(",")
else:
languages = None
benchmark_dataset = datasets.load_dataset(dataset, split="train")
if languages:
benchmark_dataset = benchmark_dataset.filter(lambda x: x["language"] in languages)
artifacts = {
"model_dict": create_model_dict(),
"use_llm": use_llm,
"mathpix_ds": None,
"llamaparse_ds": None,
}
if "mathpix" in methods:
artifacts["mathpix_ds"] = datasets.load_dataset("datalab-to/marker_benchmark_mathpix", split="train")
if "llamaparse" in methods:
artifacts["llamaparse_ds"] = datasets.load_dataset("datalab-to/marker_benchmark_llamaparse", split="train")
if "mistral" in methods:
artifacts["mistral_ds"] = datasets.load_dataset("datalab-to/marker_benchmark_mistral", split="train")
if "olmocr" in methods:
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
model = Qwen2VLForConditionalGeneration.from_pretrained("allenai/olmOCR-7B-0225-preview",
torch_dtype=torch.bfloat16).eval()
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
artifacts["olmocr_model"] = {"model": model, "processor": processor}
print(f"Running benchmark with methods: {methods} and scores: {score_types}")
result = get_method_scores(benchmark_dataset, methods, score_types, artifacts, max_rows=max_rows)
# Display benchmark scoring tables
print_scores(result, out_path, methods, score_types, default_method=methods[0], default_score_type=score_types[0])
# Write to json
with open(out_path / "result.json", "w") as f:
json.dump(result, f)
if out_dataset:
if use_llm:
out_dataset += "_llm"
dataset = build_dataset(benchmark_dataset, result, score_types, max_rows=max_rows)
dataset.push_to_hub(out_dataset, private=True)
if __name__ == "__main__":
main()

@ -0,0 +1,24 @@
from benchmarks.overall.methods.docling import DoclingMethod
from benchmarks.overall.methods.gt import GTMethod
from benchmarks.overall.methods.llamaparse import LlamaParseMethod
from benchmarks.overall.methods.marker import MarkerMethod
from benchmarks.overall.methods.mathpix import MathpixMethod
from benchmarks.overall.methods.mistral import MistralMethod
from benchmarks.overall.methods.olmocr import OlmOCRMethod
from benchmarks.overall.scorers.heuristic import HeuristicScorer
from benchmarks.overall.scorers.llm import LLMScorer
SCORE_REGISTRY = {
"heuristic": HeuristicScorer,
"llm": LLMScorer
}
METHOD_REGISTRY = {
"marker": MarkerMethod,
"gt": GTMethod,
"mathpix": MathpixMethod,
"llamaparse": LlamaParseMethod,
"docling": DoclingMethod,
"olmocr": OlmOCRMethod,
"mistral": MistralMethod
}

@ -0,0 +1,12 @@
from typing import TypedDict, List, Dict
from benchmarks.overall.scorers.schema import BlockScores
AVG_TYPE = Dict[str, Dict[str, Dict[str, List[float]]]]
class FullResult(TypedDict):
scores: Dict[int, Dict[str, Dict[str, BlockScores]]]
averages_by_type: AVG_TYPE
averages_by_block_type: AVG_TYPE
average_times: Dict[str, List[float]]
markdown: Dict[int, Dict[str, str]]

@ -0,0 +1,11 @@
from typing import List
from benchmarks.overall.scorers.schema import BlockScores
class BaseScorer:
def __init__(self):
pass
def __call__(self, sample, gt_markdown: List[str], method_markdown: str) -> BlockScores:
raise NotImplementedError()

@ -0,0 +1,113 @@
import re
import subprocess
import tempfile
from pathlib import Path
import latex2mathml.converter
class MarkdownCleaner:
def __init__(self):
pass
def __call__(self, markdown):
markdown = self.normalize_markdown(markdown) # Use pandoc to normalize
# Replace math expressions with latexml
pattern = r'(?<!\\)\$(?:\$([^$]+)\$\$|\s*([^$\n]+?)\s*\$)'
markdown = re.sub(pattern, self.standardize_math, markdown)
# Replace image urls with a generic tag
pattern = r'!\[(.*?)\]\((https?://[^\s\)]+)\)'
markdown = re.sub(pattern, r'![link]', markdown)
# Clean up stray html tags
markdown = markdown.replace("<br>", "\n")
markdown = re.sub(r"<sub>(.*?)</sub>", r"\1", markdown)
markdown = re.sub(r"<sup>(.*?)</sup>", r"\1", markdown)
markdown = re.sub(r"<span.*?>(.*?)</span>", r"\1", markdown) # Remove span tags and keep content
# Clean up markdown formatting
markdown = re.sub(r"\s+", " ", markdown)
markdown = re.sub(r"\n+", "\n", markdown)
markdown = re.sub("\\.+", ".",
markdown) # Replace repeated periods with a single period, like in table of contents
markdown = re.sub("#+", "#", markdown) # Replace repeated headers with a single header
markdown = markdown.encode().decode('unicode-escape', errors="ignore") # Decode unicode characters properly
return markdown.strip().lower()
@staticmethod
def normalize_markdown(md_text: str) -> str:
with tempfile.TemporaryDirectory() as tmp_dir:
dirpath = Path(tmp_dir)
input_file = dirpath / 'input.md'
input_file.write_text(md_text, encoding='utf-8')
# Markdown to HTML
html_file = dirpath / 'temp.html'
subprocess.run(
[
'pandoc',
str(input_file),
'-f', 'markdown+tex_math_dollars',
'-t', 'html',
'-o', str(html_file),
'--quiet'
],
check=True
)
# HTML to Markdown
output_file = dirpath / 'output.md'
subprocess.run(
[
'pandoc',
str(html_file),
'-f', 'html',
'-t', 'markdown+tex_math_dollars',
'-o', str(output_file),
'--quiet'
],
check=True
)
# Read back the normalized Markdown
normalized_md = output_file.read_text(encoding='utf-8')
return normalized_md
def standardize_math(self, match):
try:
delim = "$$" if match.group(0).startswith('$$') else "$"
math_content = match.group(1) or match.group(2)
if delim == "$$":
math_content = latex2mathml.converter.convert(math_content)
else:
math_content = self.clean_latex(math_content)
return f'{delim}{math_content}{delim}'
except Exception as e:
print(f"Failed to standardize math expression: {match.group(0)} with error: {e}")
return match.group(0)
@staticmethod
def clean_latex(latex_str):
latex_str = re.sub(r'\s+', ' ', latex_str.strip())
for tag in [r'\\text', r'\\mathrm', r'\\mathbf', r'\\textbf']:
latex_str = re.sub(tag + r'\{([^}]+)\}', r'\1', latex_str)
replacements = {
'\\times': '*',
'\\cdot': '*',
'\\div': '/',
'\\le': '<=',
'\\ge': '>=',
'\\neq': '!=',
'\\to': '\\rightarrow',
}
for old, new in replacements.items():
latex_str = latex_str.replace(old, new)
return latex_str

@ -0,0 +1,105 @@
from typing import List
from rapidfuzz import fuzz
from benchmarks.overall.scorers.clean import MarkdownCleaner
from benchmarks.overall.scorers.schema import BlockScores
from benchmarks.overall.scorers import BaseScorer
class HeuristicScorer(BaseScorer):
def __call__(self, sample, gt_markdown: List[str], method_markdown: str) -> BlockScores:
if not method_markdown:
return {
"score": 0,
"specific_scores": {
"order": 0,
"by_block": [0] * len(gt_markdown)
}
}
# Standardize inputs
gt_markdown = [self.clean_input(block) for block in gt_markdown]
method_markdown = self.clean_input(method_markdown)
alignments = self.find_fuzzy_alignments(method_markdown, gt_markdown)
scores = [alignment["score"] for alignment in alignments]
# Find order score
orders = [alignment["start"] for alignment in alignments]
correct_order = list(range(len(gt_markdown)))
actual_order = sorted(range(len(gt_markdown)), key=lambda x: orders[x])
order_score = self.kendall_tau(correct_order, actual_order)
# Weight score by sequence length
gt_weights = [len(g) for g in gt_markdown]
weighted_scores = [score * weight for score, weight in zip(scores, gt_weights)]
# Weight the score by sequence length
overall_score = sum(weighted_scores) / max(1, sum(gt_weights))
overall_score = overall_score * 0.8 + order_score * 0.2
return {
"score": overall_score,
"specific_scores": {
"order": order_score,
"by_block": scores
},
}
@staticmethod
def kendall_tau(correct_order: List[int], actual_order: List[int]) -> float:
n = len(correct_order)
concordant = 0
discordant = 0
if n <= 1:
return 100
for i in range(n):
for j in range(i + 1, n):
correct_sign = correct_order[i] - correct_order[j]
actual_sign = actual_order[i] - actual_order[j]
if (correct_sign > 0 and actual_sign > 0) or (correct_sign < 0 and actual_sign < 0):
concordant += 1
elif (correct_sign < 0 and actual_sign > 0) or (correct_sign > 0 and actual_sign < 0):
discordant += 1
total_pairs = (n * (n - 1)) // 2
tau = (concordant - discordant) / total_pairs
tau = (tau + 1) / 2 # 0-1 scale
return tau * 100 # 0-100 scale
@staticmethod
def find_fuzzy_alignments(
main_string: str,
substrings: List[str],
threshold: int = 70
) -> List[dict]:
alignments = []
for idx, substr in enumerate(substrings):
result = fuzz.partial_ratio_alignment(substr, main_string, score_cutoff=threshold)
score = 0
dest_start = 0
dest_end = 0
if result:
score = result.score
dest_start = result.dest_start
dest_end = result.dest_end
alignments.append({
"string": substr,
"start": dest_start,
"end": dest_end,
"score": score,
"idx": idx
})
return alignments
@staticmethod
def clean_input(md: str):
cleaner = MarkdownCleaner()
return cleaner(md)

@ -0,0 +1,160 @@
import json
import os
import tempfile
import time
from typing import List
from PIL import Image
from google.genai.errors import APIError
from google import genai
import pypdfium2 as pdfium
from benchmarks.overall.scorers import BaseScorer, BlockScores
from marker.settings import settings
rating_prompt = """
You're a document analysis expert who is comparing some markdown to an image to make sure the markdown is correct. You're rating how effectively the provided markdown represents the full text and formatting in the image provided.
You're given an image, along with the extracted markdown:
- Some parts of the page may have been recognized as images and linked from the markdown, like `![](_page_0_Picture_0.jpeg)`.
- Tables will be formatted as Github flavored markdown.
- Block equations will be in LaTeX.
- The image and markdown may be in any language.
- The markdown is based on the text extracted from the document, and sometimes the document may have had bad OCR applied to it, resulting in gibberish text.
The markdown should fully capture the meaning and formatting of the text in the image. You'll evaluate the markdown based on the image provided.
**Instructions**
Follow this process to evaluate the markdown:
1. Carefully examine the image.
2. Carefully examine the markdown input provided.
3. Compare the image to the markdown representation. Does the markdown representation properly represent the important text and formatting in the image?
4. Assign component scores, as described below.
These are the primary scores:
- Overall - the overall quality of the markdown as compared to the image.
- Text quality - the quality of the text extraction from the image.
- Formatting quality - the quality of the formatting applied to the markdown, as compared to the image.
Depending on which elements are present in the markdown, you will assign element-specific scores.
- Tables - how effectively the tables have been extracted and formatted.
- Forms - how effectively the forms have extracted and formatted.
- Equations - how effectively block equations have been converted to LaTeX.
- Section headers - if all of the section headers have been detected, and the right levels set.
- Lists - if the lists have been properly extracted and formatted.
- Images - if images are identified and placed correctly.
Notes on scoring:
- To get a 5/5, all of the important text from the image must appear in the markdown, and the formatting should be correct (minor mistakes okay). It's okay to omit some text that isn't important to the meaning, like page numbers and chapter headings. If the entire page is an image, it's okay if the markdown is just a link to the image, unless the image would be better represented as text.
- A 3/5 may have small missing text elements from the markdown and/or moderate formatting issues.
- A 1/5 will have major missing text segments from the markdown or completely unreadable formatting.
- Use 0/5 if a field isn't applicable, like if the image doesn't contain a table.
If text that is important to the meaning of the document is missing, do not score higher than 3/5.
Output json, like in the example below.
**Example**
Input
```markdown
# Section 1
This is some *markdown* extracted from a document. Here is a block equation:
$$\frac{ab \cdot x^5 + x^2 + 2 \cdot x + 123}{t}$$
```
Output
```json
{
"image_description": "In the image, there is a section header 'Section 1', followed by some text and a block equation.",
"markdown_description": "In the markdown, there is a section header 'Section 1', followed by some text and a block equation.",
"comparison": "The text and formatting matches the image. There are no formatting or text extraction issues. The equations and section headers are correct.",
"overall": 5,
"text": 5,
"formatting": 5,
"section_headers": 5,
"tables": 0,
"forms": 0,
"equations": 5,
"lists": 0,
"images": 0
}
```
**Input**
```markdown
{{markdown}}
```
**Output**
"""
comparison_keys = ["comparison"]
description_keys = ["image_description", "markdown_description"]
text_keys = comparison_keys + description_keys
score_keys = ["overall", "text", "formatting", "section_headers", "tables", "forms", "equations",
"lists", "images"]
class LLMScorer(BaseScorer):
def __call__(self, sample, gt_markdown: List[str], markdown: str) -> BlockScores:
pdf_bytes = sample["pdf"]
with tempfile.NamedTemporaryFile(suffix=".pdf") as f:
f.write(pdf_bytes)
f.flush()
f.seek(0)
doc = pdfium.PdfDocument(f.name)
img = doc[0].render(scale=96/72).to_pil()
doc.close()
return self.llm_rater(img, markdown)
def llm_rater(self, img: Image.Image, markdown: str) -> BlockScores:
if not markdown:
null_scores = {k: 1 for k in score_keys}
text_scores = {k: "" for k in text_keys}
null_scores.update(text_scores)
return {
"score": 1,
"specific_scores": null_scores
}
req_keys = text_keys + score_keys
properties = {}
for key in req_keys:
content_type = "INTEGER" if key in score_keys else "STRING"
properties[key] = {"type": content_type}
response_schema = {
"required": req_keys,
"properties": properties,
"type": "OBJECT"
}
prompt = rating_prompt.replace("{{markdown}}", markdown)
response = self.llm_response_wrapper([img, prompt], response_schema)
assert all([k in response for k in req_keys]), f"Missing keys in response: {response}"
return {
"score": response["overall"],
"specific_scores": response,
}
def llm_response_wrapper(self, prompt, response_schema, depth=0):
client = genai.Client(
http_options={"timeout": 60000},
vertexai=True,
project=os.getenv("VERTEX_PROJECT_ID"),
location=os.getenv("VERTEX_LOCATION"),
)
try:
responses = client.models.generate_content(
model="gemini-2.0-flash-001",
contents=prompt,
config={
"temperature": 0,
"response_schema": response_schema,
"response_mime_type": "application/json",
},
)
output = responses.candidates[0].content.parts[0].text
return json.loads(output)
except APIError as e:
print(f"Hit Gemini rate limit, waiting 120 seconds")
time.sleep(120)
if depth > 2:
raise e
return self.llm_response_wrapper(prompt, response_schema, depth + 1)

@ -0,0 +1,6 @@
from typing import TypedDict, List, Optional, Dict
class BlockScores(TypedDict):
score: float
specific_scores: Dict[str, float | List[float]]

@ -0,0 +1,48 @@
import json
from PIL import Image
from google import genai
from google.genai import types
from io import BytesIO
from pydantic import BaseModel
from marker.settings import settings
prompt = """
You're an expert document analyst who is good at turning tables in documents into HTML. Analyze the provided image, and convert it to a faithful HTML representation.
Guidelines:
- Keep the HTML simple and concise.
- Only include the <table> tag and contents.
- Only use <table>, <tr>, and <td> tags. Only use the colspan and rowspan attributes if necessary. Do not use <tbody>, <thead>, or <th> tags.
- Make sure the table is as faithful to the image as possible with the given tags.
**Instructions**
1. Analyze the image, and determine the table structure.
2. Convert the table image to HTML, following the guidelines above.
3. Output only the HTML for the table, starting with the <table> tag and ending with the </table> tag.
""".strip()
class TableSchema(BaseModel):
table_html: str
def gemini_table_rec(image: Image.Image):
client = genai.Client(
api_key=settings.GOOGLE_API_KEY,
http_options={"timeout": 60000}
)
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
responses = client.models.generate_content(
model="gemini-2.0-flash",
contents=[types.Part.from_bytes(data=image_bytes.getvalue(), mime_type="image/png"), prompt], # According to gemini docs, it performs better if the image is the first element
config={
"temperature": 0,
"response_schema": TableSchema,
"response_mime_type": "application/json",
},
)
output = responses.candidates[0].content.parts[0].text
return json.loads(output)["table_html"]

@ -0,0 +1,182 @@
from typing import List
import numpy as np
from bs4 import BeautifulSoup
import pypdfium2 as pdfium
from tqdm import tqdm
import base64
import tempfile
from benchmarks.table.gemini import gemini_table_rec
from marker.config.parser import ConfigParser
from marker.converters.table import TableConverter
from marker.models import create_model_dict
from marker.processors.llm.llm_table import LLMTableProcessor
from marker.processors.table import TableProcessor
from marker.renderers.json import JSONBlockOutput
from marker.schema.polygon import PolygonBox
from marker.util import matrix_intersection_area
def extract_tables(children: List[JSONBlockOutput]):
tables = []
for child in children:
if child.block_type == 'Table':
tables.append(child)
elif child.children:
tables.extend(extract_tables(child.children))
return tables
def fix_table_html(table_html: str) -> str:
marker_table_soup = BeautifulSoup(table_html, 'html.parser')
tbody = marker_table_soup.find('tbody')
if tbody:
tbody.unwrap()
for th_tag in marker_table_soup.find_all('th'):
th_tag.name = 'td'
for br_tag in marker_table_soup.find_all('br'):
br_tag.replace_with(marker_table_soup.new_string(''))
marker_table_html = str(marker_table_soup)
marker_table_html = marker_table_html.replace("\n", " ") # Fintabnet uses spaces instead of newlines
return marker_table_html
def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, max_rows: int, use_gemini: bool):
models = create_model_dict()
config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size, "disable_tqdm": True})
total_unaligned = 0
results = []
iterations = len(dataset)
if max_rows is not None:
iterations = min(max_rows, len(dataset))
for i in tqdm(range(iterations), desc='Converting Tables'):
try:
row = dataset[i]
pdf_binary = base64.b64decode(row['pdf'])
gt_tables = row['tables'] # Already sorted by reading order, which is what marker returns
# Only use the basic table processors
converter = TableConverter(
config=config_parser.generate_config_dict(),
artifact_dict=models,
processor_list=[
"marker.processors.table.TableProcessor",
"marker.processors.llm.llm_table.LLMTableProcessor",
],
renderer=config_parser.get_renderer()
)
with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as temp_pdf_file:
temp_pdf_file.write(pdf_binary)
temp_pdf_file.seek(0)
marker_json = converter(temp_pdf_file.name).children
doc = pdfium.PdfDocument(temp_pdf_file.name)
page_image = doc[0].render(scale=96/72).to_pil()
doc.close()
if len(marker_json) == 0 or len(gt_tables) == 0:
print(f'No tables detected, skipping...')
total_unaligned += len(gt_tables)
continue
marker_tables = extract_tables(marker_json)
marker_table_boxes = [table.bbox for table in marker_tables]
page_bbox = marker_json[0].bbox
if len(marker_tables) != len(gt_tables):
print(f'Number of tables do not match, skipping...')
total_unaligned += len(gt_tables)
continue
table_images = [
page_image.crop(
PolygonBox.from_bbox(bbox)
.rescale(
(page_bbox[2], page_bbox[3]), (page_image.width, page_image.height)
).bbox
)
for bbox
in marker_table_boxes
]
# Normalize the bboxes
for bbox in marker_table_boxes:
bbox[0] = bbox[0] / page_bbox[2]
bbox[1] = bbox[1] / page_bbox[3]
bbox[2] = bbox[2] / page_bbox[2]
bbox[3] = bbox[3] / page_bbox[3]
gt_boxes = [table['normalized_bbox'] for table in gt_tables]
gt_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in gt_boxes]
marker_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in marker_table_boxes]
table_alignments = matrix_intersection_area(gt_boxes, marker_table_boxes)
aligned_tables = []
used_tables = set()
unaligned_tables = set()
for table_idx, alignment in enumerate(table_alignments):
try:
max_area = np.max(alignment)
aligned_idx = np.argmax(alignment)
except ValueError:
# No alignment found
unaligned_tables.add(table_idx)
continue
if max_area <= .01:
# No alignment found
unaligned_tables.add(table_idx)
continue
if aligned_idx in used_tables:
# Marker table already aligned with another gt table
unaligned_tables.add(table_idx)
continue
# Gt table doesn't align well with any marker table
gt_table_pct = gt_areas[table_idx] / max_area
if not .85 < gt_table_pct < 1.15:
unaligned_tables.add(table_idx)
continue
# Marker table doesn't align with gt table
marker_table_pct = marker_areas[aligned_idx] / max_area
if not .85 < marker_table_pct < 1.15:
unaligned_tables.add(table_idx)
continue
gemini_html = ""
if use_gemini:
try:
gemini_html = gemini_table_rec(table_images[aligned_idx])
except Exception as e:
print(f'Gemini failed: {e}')
aligned_tables.append(
(marker_tables[aligned_idx], gt_tables[table_idx], gemini_html)
)
used_tables.add(aligned_idx)
total_unaligned += len(unaligned_tables)
for marker_table, gt_table, gemini_table in aligned_tables:
gt_table_html = gt_table['html']
# marker wraps the table in <tbody> which fintabnet data doesn't
# Fintabnet doesn't use th tags, need to be replaced for fair comparison
marker_table_html = fix_table_html(marker_table.html)
gemini_table_html = fix_table_html(gemini_table)
results.append({
"marker_table": marker_table_html,
"gt_table": gt_table_html,
"gemini_table": gemini_table_html
})
except pdfium.PdfiumError:
print('Broken PDF, Skipping...')
continue
return results, total_unaligned

@ -0,0 +1,109 @@
""""
TEDS Code Adapted from https://github.com/ibm-aur-nlp/EDD
"""
import distance
from apted import APTED, Config
from apted.helpers import Tree
from lxml import html
from collections import deque
def wrap_table_html(table_html:str)->str:
return f'<html><body>{table_html}</body></html>'
class TableTree(Tree):
def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
self.tag = tag
self.colspan = colspan
self.rowspan = rowspan
self.content = content
# Sets self.name and self.children
super().__init__(tag, *children)
def bracket(self):
"""Show tree using brackets notation"""
if self.tag == 'td':
result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
(self.tag, self.colspan, self.rowspan, self.content)
else:
result = '"tag": %s' % self.tag
for child in self.children:
result += child.bracket()
return "{{{}}}".format(result)
class CustomConfig(Config):
@staticmethod
def maximum(*sequences):
return max(map(len, sequences))
def normalized_distance(self, *sequences):
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
def rename(self, node1, node2):
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
return 1.
if node1.tag == 'td':
if node1.content or node2.content:
return self.normalized_distance(node1.content, node2.content)
return 0.
def tokenize(node):
"""
Tokenizes table cells
"""
global __tokens__
__tokens__.append('<%s>' % node.tag)
if node.text is not None:
__tokens__ += list(node.text)
for n in node.getchildren():
tokenize(n)
if node.tag != 'unk':
__tokens__.append('</%s>' % node.tag)
if node.tag != 'td' and node.tail is not None:
__tokens__ += list(node.tail)
def tree_convert_html(node, convert_cell=False, parent=None):
"""
Converts HTML tree to the format required by apted
"""
global __tokens__
if node.tag == 'td':
if convert_cell:
__tokens__ = []
tokenize(node)
cell = __tokens__[1:-1].copy()
else:
cell = []
new_node = TableTree(node.tag,
int(node.attrib.get('colspan', '1')),
int(node.attrib.get('rowspan', '1')),
cell, *deque())
else:
new_node = TableTree(node.tag, None, None, None, *deque())
if parent is not None:
parent.children.append(new_node)
if node.tag != 'td':
for n in node.getchildren():
tree_convert_html(n, convert_cell, new_node)
if parent is None:
return new_node
def similarity_eval_html(pred, true, structure_only=False):
"""
Computes TEDS score between the prediction and the ground truth of a given samples
"""
pred, true = html.fromstring(pred), html.fromstring(true)
if pred.xpath('body/table') and true.xpath('body/table'):
pred = pred.xpath('body/table')[0]
true = true.xpath('body/table')[0]
n_nodes_pred = len(pred.xpath(".//*"))
n_nodes_true = len(true.xpath(".//*"))
tree_pred = tree_convert_html(pred, convert_cell=not structure_only)
tree_true = tree_convert_html(true, convert_cell=not structure_only)
n_nodes = max(n_nodes_pred, n_nodes_true)
distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
return 1.0 - (float(distance) / n_nodes)
else:
return 0.0

@ -0,0 +1,97 @@
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for an op, which is not supported on MPS
from pathlib import Path
from itertools import repeat
from typing import List
import time
import datasets
from tqdm import tqdm
import click
from tabulate import tabulate
import json
from concurrent.futures import ProcessPoolExecutor
from marker.settings import settings
from benchmarks.table.inference import inference_tables
from scoring import wrap_table_html, similarity_eval_html
def update_teds_score(result, prefix: str = "marker"):
prediction, ground_truth = result[f'{prefix}_table'], result['gt_table']
prediction, ground_truth = wrap_table_html(prediction), wrap_table_html(ground_truth)
score = similarity_eval_html(prediction, ground_truth)
result.update({f'{prefix}_score':score})
return result
@click.command(help="Benchmark Table to HTML Conversion")
@click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "table"), help="Output path for results.")
@click.option("--dataset", type=str, default="datalab-to/fintabnet_bench_marker", help="Dataset to use")
@click.option("--max_rows", type=int, default=None, help="Maximum number of PDFs to process")
@click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
@click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
@click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.")
@click.option("--use_gemini", is_flag=True, help="Evaluate Gemini for table recognition.")
def main(
result_path: str,
dataset: str,
max_rows: int,
max_workers: int,
use_llm: bool,
table_rec_batch_size: int | None,
use_gemini: bool = False
):
start = time.time()
dataset = datasets.load_dataset(dataset, split='train')
dataset = dataset.shuffle(seed=0)
results, total_unaligned = inference_tables(dataset, use_llm, table_rec_batch_size, max_rows, use_gemini)
print(f"Total time: {time.time() - start}.")
print(f"Could not align {total_unaligned} tables from fintabnet.")
with ProcessPoolExecutor(max_workers=max_workers) as executor:
marker_results = list(
tqdm(
executor.map(update_teds_score, results), desc='Computing alignment scores', total=len(results)
)
)
avg_score = sum([r["marker_score"] for r in marker_results]) / len(marker_results)
headers = ["Avg score", "Total tables"]
data = [f"{avg_score:.3f}", len(marker_results)]
gemini_results = None
if use_gemini:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
gemini_results = list(
tqdm(
executor.map(update_teds_score, results, repeat("gemini")), desc='Computing Gemini scores',
total=len(results)
)
)
avg_gemini_score = sum([r["gemini_score"] for r in gemini_results]) / len(gemini_results)
headers.append("Avg Gemini score")
data.append(f"{avg_gemini_score:.3f}")
table = tabulate([data], headers=headers, tablefmt="github")
print(table)
print("Avg score computed by comparing marker predicted HTML with original HTML")
results = {
"marker": marker_results,
"gemini": gemini_results
}
out_path = Path(result_path)
out_path.mkdir(parents=True, exist_ok=True)
with open(out_path / "table.json", "w+") as f:
json.dump(results, f, indent=2)
print(f"Results saved to {out_path}.")
if __name__ == '__main__':
main()

@ -0,0 +1,40 @@
import time
import torch
import click
import pypdfium2 as pdfium
from tqdm import tqdm
from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict
@click.command(help="Benchmark PDF to MD conversion throughput.")
@click.argument("pdf_path", type=str)
def main(pdf_path):
print(f"Converting {pdf_path} to markdown...")
pdf = pdfium.PdfDocument(pdf_path)
page_count = len(pdf)
pdf.close()
model_dict = create_model_dict()
torch.cuda.reset_peak_memory_stats()
times = []
for i in tqdm(range(10), desc="Benchmarking"):
block_converter = PdfConverter(
artifact_dict=model_dict,
config={"disable_tqdm": True}
)
start = time.time()
block_converter(pdf_path)
total = time.time() - start
times.append(total)
max_gpu_vram = torch.cuda.max_memory_allocated() / 1024 ** 3
print(f"Converted {page_count} pages in {sum(times)/len(times):.2f} seconds.")
print(f"Max GPU VRAM: {max_gpu_vram:.2f} GB")
if __name__ == "__main__":
main()

@ -0,0 +1,33 @@
import json
import argparse
def verify_scores(file_path):
with open(file_path, 'r') as file:
data = json.load(file)
raw_scores = [data["scores"][k] for k in data["scores"]]
marker_scores = [r["marker"]["heuristic"]["score"] for r in raw_scores]
marker_score = sum(marker_scores) / len(marker_scores)
if marker_score < 90:
raise ValueError("Marker score below 90")
def verify_table_scores(file_path):
with open(file_path, 'r') as file:
data = json.load(file)
avg = sum([r["marker_score"] for r in data["marker"]]) / len(data)
if avg < 0.7:
raise ValueError("Average score is below the required threshold of 0.7")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Verify benchmark scores")
parser.add_argument("file_path", type=str, help="Path to the json file")
parser.add_argument("--type", type=str, help="Type of file to verify", default="marker")
args = parser.parse_args()
if args.type == "marker":
verify_scores(args.file_path)
elif args.type == "table":
verify_table_scores(args.file_path)

@ -0,0 +1,4 @@
from marker.scripts.chunk_convert import chunk_convert_cli
if __name__ == "__main__":
chunk_convert_cli()

@ -0,0 +1,4 @@
from marker.scripts.convert import convert_cli
if __name__ == "__main__":
convert_cli()

@ -0,0 +1,4 @@
from marker.scripts.convert_single import convert_single_cli
if __name__ == "__main__":
convert_single_cli()

@ -0,0 +1,3 @@
latex
pdfs
references

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

Binary file not shown.

After

Width:  |  Height:  |  Size: 78 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 60 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

@ -0,0 +1,270 @@
# An Aggregated Multicolumn Dilated Convolution Network for Perspective-Free Counting
Diptodip Deb Georgia Institute of Technology diptodipdeb@gatech.edu
# Abstract
*We propose the use of dilated filters to construct an aggregation module in a multicolumn convolutional neural network for perspective-free counting. Counting is a common problem in computer vision (e.g. traffic on the street or pedestrians in a crowd). Modern approaches to the counting problem involve the production of a density map via regression whose integral is equal to the number of objects in the image. However, objects in the image can occur at different scales (e.g. due to perspective effects) which can make it difficult for a learning agent to learn the proper density map. While the use of multiple columns to extract multiscale information from images has been shown before, our approach aggregates the multiscale information gathered by the multicolumn convolutional neural network to improve performance. Our experiments show that our proposed network outperforms the state-of-the-art on many benchmark datasets, and also that using our aggregation module in combination with a higher number of columns is beneficial for multiscale counting.*
# 1. Introduction
Learning to count the number of objects in an image is a deceptively difficult problem with many interesting applications, such as surveillance [\[20\]](#page-8-0), traffic monitoring [\[14\]](#page-8-1) and medical image analysis [\[22\]](#page-8-2). In many of these application areas, the objects to be counted vary widely in appearance, size and shape, and labeled training data is typically sparse. These factors pose a significant computer vision and machine learning challenge.
Lempitsky et al. [\[15\]](#page-8-3) showed that it is possible to learn to count without learning to explicitly detect and localize individual objects. Instead, they propose learning to predict a density map whose integral over the image equals the number of objects in the image. This approach has been adopted by many later works (Cf. [\[18,](#page-8-4) [28\]](#page-9-0)).
However, in many counting problems, such as those
Jonathan Ventura University of Colorado Colorado Springs jventura@uccs.edu
counting cells in a microscope image, pedestrians in a crowd, or vehicles in a traffic jam, regressors trained on a single image scale are not reliable [\[18\]](#page-8-4). This is due to a variety of challenges including overlap of objects and perspective effects which cause significant variance in object shape, size and appearance.
The most successful recent approaches address this issue by explicitly incorporating multi-scale information in the network [\[18,](#page-8-4)[28\]](#page-9-0). These approaches either combine multiple networks which take input patches of different sizes [\[18\]](#page-8-4) or combine multiple filtering paths ("columns") which have different size filters [\[28\]](#page-9-0).
Following on the intuition that multiscale integration is key to achieving good counting performance, we propose to incorporate dilated filters [\[25\]](#page-8-5) into a multicolumn convolutional neural network design [\[28\]](#page-9-0). Dilated filters exponentially increase the network's receptive field without an exponential increase in parameters, allowing for efficient use of multiscale information. Convolutional neural networks with dilated filters have proven to provide competitive performance in image segmentation where multiscale analysis is also critical [\[25,](#page-8-5) [26\]](#page-8-6). By incorporating dilated filters into the multicolumn network design, we greatly increase the ability of the network to selectively aggregate multiscale information, without the need for explicit perspective maps during training and testing. We propose the "aggregated multicolumn dilated convolution network" or AMDCN which uses dilations to aggregate multiscale information. Our extensive experimental evaluation shows that this proposed network outperforms previous methods on many benchmark datasets.
# 2. Related Work
Counting using a supervised regressor to formulate a density map was first shown by [\[15\]](#page-8-3). In this paper, Lempitsky et al. show that the minimal annotation of a single dot blurred by a Gaussian kernel produces a sufficient density map to train a network to count. All of the counting methods that we examine as well as the method we use in
![](_page_1_Figure_0.jpeg)
<span id="page-1-0"></span>Figure 1. Fully convolutional architecture diagram (not to scale). Arrows show separate columns that all take the same input. At the end of the columns, the feature maps are merged (concatenated) together and passed to another series of dilated convolutions: the aggregator, which can aggregate the multiscale information collected by the columns [\[25\]](#page-8-5). The input image is I with C channels. The output single channel density map is D, and integrating over this map (summing the pixels) results in the final count. Initial filter sizes are labeled with brackets or lines. Convolution operations are shown as flat rectangles, feature maps are shown as prisms. The number below each filter represents the dilation rate (1 means no dilation).
our paper follow this method of producing a density map via regression. This is particularly advantageous because a sufficiently accurate regressor can also locate the objects in the image via this method. However, the Lempitsky paper ignores the issue of perspective scaling and other scaling issues. The work of [\[27\]](#page-8-7) introduces CNNs (convolutional neural networks) for the purposes of crowd counting, but performs regression on similarly scaled image patches.
These issues are addressed by the work of [\[18\]](#page-8-4). Rubio et al. show that a fully convolutional neural network can be used to produce a supervised regressor that produces density maps as in [\[15\]](#page-8-3). They further demonstrate a method dubbed HydraCNN which essentially combines multiple convolutional networks that take in differently scaled image patches in order to incorporate multiscale, global information from the image. The premise of this method is that a single regressor will fail to accurately represent the difference in values of the features of an image caused by perspective shifts (scaling effects) [\[18\]](#page-8-4).
However, the architectures of both [\[18\]](#page-8-4) and [\[27\]](#page-8-7) are not fully convolutional due to requiring multiple image patches and, as discussed in [\[25\]](#page-8-5), the experiments of [\[11,](#page-8-8) [17\]](#page-8-9) and [\[9,](#page-8-10) [12,](#page-8-11) [16\]](#page-8-12) leave it unclear as to whether rescaling patches of the image is truly necessary in order to solve dense prediction problems via convolutional neural networks. Moreover, these approaches seem to saturate in performance at three columns, which means the network is extracting information from fewer scales. The work of [\[25\]](#page-8-5) proposes the use of dilated convolutions as a simpler alternative that does not require sampling of rescaled image patches to provide global, scale-aware information to the network. A fully convolutional approach to multiscale counting has been proposed by [\[28\]](#page-9-0), in which a multicolumn convolutional network gathers features of different scales by using convolutions of increasing kernel sizes from column to column instead of scaling image patches. Further, DeepLab has used dilated convolutions in multiple columns to extract scale information for segmentation [\[8\]](#page-8-13). We build on these approaches with our aggregator module as described in Section [3.1,](#page-2-0) which should allow for extracting information from more scales.
It should be noted that other methods of counting exist, including training a network to recognize deep object features via only providing the counts of the objects of interest in an image [\[21\]](#page-8-14) and using CNNs (convolutional neural networks) along with boosting in order to improve the results
![](_page_2_Picture_0.jpeg)
Figure 2. UCF sample results. Left: input counting image. Middle: Ground truth density map. Right: AMDCN prediction of density map on test image. The network never saw these images during training. All density maps are one channel only (i.e. grayscale), but are colored here for clarity.
<span id="page-2-2"></span>of regression for production of density maps [\[24\]](#page-8-15). In the same spirit, [\[4\]](#page-8-16) combines deep and shallow convolutions within the same network, providing accurate counting of dense objects (e.g. the UCF50 crowd dataset).
In this paper, however, we aim to apply the dilated convolution method of [\[25\]](#page-8-5), which has shown to be able to incorporate multiscale perspective information without using multiple inputs or a complicated network architecture, as well as the multicolumn approach of [\[8,](#page-8-13) [28\]](#page-9-0) to aggregate multiscale information for the counting problem.
# 3. Method
### <span id="page-2-0"></span>3.1. Dilated Convolutions for Multicolumn Networks
We propose the use of dilated convolutions as an attractive alternative to the architecture of the HydraCNN [\[18\]](#page-8-4), which seems to saturate in performance at 3 or more columns. We refer to our proposed network as the aggregated multicolumn dilated convolution network1, henceforth shortened as the AMDCN. The architecture of the AMDCN is inspired by the multicolumn counting network of [\[28\]](#page-9-0). Extracting features from multiple scales is a good idea when attempting to perform perspective-free counting and increasing the convolution kernel size across columns is an efficient method of doing so. However, the number of parameters increases exponentially as larger kernels are used in these columns to extract features at larger scales. Therefore, we propose using dilated convolutions rather than larger kernels.
Dilated convolutions, as discussed in [\[25\]](#page-8-5), allow for the exponential increase of the receptive field with a linear increase in the number of parameters with respect to each hidden layer.
In a traditional 2D convolution, we define a real valued function $F : \mathbb{Z}^2 \rightarrow \mathbb{R}$, an input $\Omega_r = [-r, r]^2 \in \mathbb{Z}^2$, and a filter function $k : \Omega_r \rightarrow \mathbb{R}$. In this case, a convolution operation as defined in [\[25\]](#page-8-5) is given by
$$(F*k)(\mathbf{p}) = \sum_{\mathbf{s}+\mathbf{t}=\mathbf{p}} F(\mathbf{s})k(\mathbf{t}).\tag{1}$$
A dilated convolution is essentially a generalization of the traditional 2D convolution that allows the operation to skip some inputs. This enables an increase in the size of the filter (i.e. the size of the receptive field) without losing resolution. Formally, we define from [\[25\]](#page-8-5) the dilated convolution as
$$(F \ast_l k)(\mathbf{p}) = \sum_{\mathbf{s} + l\mathbf{t} = \mathbf{p}} F(\mathbf{s}) k(\mathbf{t}) \tag{2}$$
where l is the index of the current layer of the convolution.
Using dilations to construct the aggregator in combination with the multicolumn idea will allow for the construction of a network with more than just 3 or 4 columns as in [\[28\]](#page-9-0) and [\[8\]](#page-8-13), because the aggregator should prevent the saturation of performance with increasing numbers of columns. Therefore the network will be able to extract useful features from more scales. We take advantage of dilations within the columns as well to provide large receptive fields with fewer parameters.
Looking at more scales should allow for more accurate regression of the density map. However, because not all scales will be relevant, we extend the network beyond a simple 1 × 1 convolution after the merged columns. Instead, we construct a second part of the network, the aggregator, which sets our method apart from [\[28\]](#page-9-0), [\[8\]](#page-8-13), and other multicolumn networks. This aggregator is another series of dilated convolutions that should appropriately consolidate the multiscale information collected by the columns. This is a capability of dilated convolutions observed by [\[25\]](#page-8-5). While papers such as [\[28\]](#page-9-0) and [\[8\]](#page-8-13) have shown that multiple columns and dilated columns are useful in extracting multiscale information, we argue in this paper that the simple aggregator module built using dilated convolutions is able to effectively make use multiscale information from multiple columns. We show compelling evidence for these claims in Section 4.5.
The network as shown in Figure [1](#page-1-0) contains 5 columns. Note that dilations allow us to use more columns for counting than [\[28\]](#page-9-0) or [\[8\]](#page-8-13). Each column looks at a larger scale than the previous (the exact dilations can also be seen in Figure [1](#page-1-0)). There are 32 feature maps for each convolution, and all inputs are zero padded prior to each convolution in order to maintain the same data shape from input to output. That is, an image input to this network will result in a density map of the same dimensions. All activations in the specified network are ReLUs. Our input pixel values are floating point 32 bit values from 0 to 1. We center our inputs at 0 by subtracting the per channel mean from each channel. When
<span id="page-2-1"></span>1 Implementation available on [https://github.com/](https://github.com/diptodip/counting) [diptodip/counting](https://github.com/diptodip/counting).
training, we use a scaled mean absolute error for our loss function:
$$L = \frac{1}{n} \sum_{i=1}^{n} |\hat{y}_i - \gamma y_i| \tag{3}$$
where $\n \gamma\n $ is the scale factor, $\n \hat{y}_i\n $ is the prediction, $\n y_i\n $ is the true value, and $n$ is the number of pixels. We use a scaled mean absolute error because the target values are so small that it is numerically unstable to regress to these values. At testing time, when retrieving the output density map from the network, we scale the pixel values by $\gamma^{-1}$ to obtain the correct value. This approach is more numerically stable and avoids having the network learn to output only zeros by weighting the nonzero values highly. For all our datasets, we set $\gamma = 255$.
#### 3.2. Experiments
We evaluated the performance of dilated convolutions against various counting methods on a variety of common counting datasets: UCF50 crowd data, TRANCOS traffic data [\[18\]](#page-8-4), UCSD crowd data [\[5\]](#page-8-17), and WorldExpo crowd data [\[27\]](#page-8-7). For each of these data sets, we used labels given by the corresponding density map for each image. An example of this is shown in Figure [2.](#page-2-2) We have performed experiments on the four different splits of the UCSD data as used in [\[18\]](#page-8-4) and the split of the UCSD data as used in [\[28\]](#page-9-0) (which we call the original split). We also evaluated the performance of our network on the TRANCOS traffic dataset [\[14\]](#page-8-1). We have also experimented with higher density datasets for crowd counting, namely WorldExpo and UCF.
We have observed that multicolumn dilations produce density maps (and therefore counts) that often have lower loss than those of HydraCNN [\[18\]](#page-8-4) and [\[28\]](#page-9-0). We measure density map regression loss via a scaled mean absolute error loss during training. We compare accuracy of the counts via mean absolute error for the crowd datasets and the GAME metric in the TRANCOS dataset as explained in Section [3.2.2.](#page-3-0) Beyond the comparison to HydraCNN, we will also compare to other recent convolutional counting methods, especially those of [\[21\]](#page-8-14), [\[24\]](#page-8-15), and [\[4\]](#page-8-16) where possible.
For all datasets, we generally use patched input images and ground truth density maps produced by summing a Gaussian of a fixed size ($σ$) for each object for training. This size varies from dataset to dataset, but remains constant within a dataset with the exception of cases in which a perspective map is used. This is explained per dataset. All experiments were performed using Keras with the Adam optimizer [\[10\]](#page-8-18). The learning rates used are detailed per dataset. For testing, we also use patches that can either be directly pieced together or overlapped and averaged except in the case of UCF, for which we run our network on the full image.
Furthermore, we performed a set of experiments in which we varied the number of columns from 1 to 5 (simply by including or not including the columns as specified in Figure [1,](#page-1-0) starting with the smallest filter column and adding larger filter columns one by one). Essentially, the network is allowed to extract information at larger and larger scales in addition to the smaller scales as we include each column. We then performed the same set of experiments, varying the number of columns, but with the aggregator module removed. We perform these experiments on the original split of UCSD as specified in Section [3.2.3](#page-4-0) and [\[5\]](#page-8-17), the TRAN-COS dataset, and the WorldExpo dataset because these are relatively large and well defined datasets. We limit the number of epochs to 10 for all of these sets of experiments in order to control for the effect of learning time, and also compare all results using MAE for consistency. These experiments are key to determining the efficacy of the aggregator in effectively combining multiscale information and in providing evidence to support the use of multiple columns to extract multiscale information from images. We report the results of these ablation studies in Section [4.5.](#page-5-0)
#### 3.2.1 UCF50 Crowd Counting
UCF is a particularly challenging crowd counting dataset. There are only 50 images in the whole dataset and they are all of varying sizes and from different scenes. The number of people also varies between images from less than 100 to the thousands. The average image has on the order of 1000 people. The difficulty is due to the combination of the very low number of images in the dataset and the fact that the images are all of varying scenes, making high quality generalization crucial. Furthermore, perspective effects are particularly noticeable for many images in this dataset. Despite this, there is no perspective information available for this dataset.
We take 1600 random patches of size $150 \,\times\, 150$ for the training. For testing, we do not densely scan the image as in [\[18\]](#page-8-4) but instead test on the whole image. In order to standardize the image sizes, we pad each image out with zeros until all images are $1024 \times 1024$. We then suppress output in the regions where we added padding when testing. This provides a cleaner resulting density map for these large crowds. The ground truth density maps are produced by annotating each object with a Gaussian of $\sigma = 15$.
#### <span id="page-3-0"></span>3.2.2 TRANCOS Traffic Counting
TRANCOS is a traffic counting dataset that comes with its own metric [\[14\]](#page-8-1). This metric is known as $GAME$, which stands for Grid Average Mean absolute Error. $GAME $splits a given density map into $4^{L}$ grids, or subarrays, and obtains a mean absolute error within each grid separately. The value of $L$ is a parameter chosen by the user. These individual errors are summed to obtain the final error for a particular image. The intuition behind this metric is that it is desirable to penalize a density map whose overall count might match the ground truth, but whose shape does not match the ground truth [\[14\]](#page-8-1). More formally, we define
$$GAME(L) = \frac{1}{N} \cdot \sum_{n=1}^{N} \left( \sum_{l=1}^{4^L} |e_n^l - t_n^l| \right) \qquad (4)$$
where $N$ refers to the number of images, $L$ is the level parameter for $GAME$, $e_{n}^{l}$ is the predicted or estimated count in region $l$ of image $n$ and $t_{n}^{l}$ is the ground truth count in region $l$ of image $n$ [\[14\]](#page-8-1).
For training this dataset, we take 1600 randomly sampled patches of size 80 × 80. For testing this dataset, we take 80 × 80 non-overlapping patches which we can stitch back together into the full-sized 640 × 480 images. We trained the AMDCN network with density maps produced with a Gaussian of $σ$ = 15 as specified in [\[18\]](#page-8-4).
#### <span id="page-4-0"></span>3.2.3 UCSD Crowd Counting
The UCSD crowd counting dataset consists of frames of video of a sidewalk. There are relatively few people in view at any given time (approximately 25 on average). Furthermore, because the dataset comes from a video, there are many nearly identical images in the dataset. For this dataset, there have been two different ways to split the data into train and test sets. Therefore, we report results using both methods of splitting the data. The first method consists of four different splits: maximal, downscale, upscale, and minimal. Minimal is particularly challenging as the train set contains only 10 images. Moreover, upscale appears to be the easiest for the majority of methods [\[18\]](#page-8-4). The second method of splitting this data is much more succinct, leaving 1200 images in the testing set and 800 images in the training set [\[28\]](#page-9-0). This split comes from the original paper, so we call it the original split [\[5\]](#page-8-17).
For this dataset, each object is annotated with a 2D Gaussian of covariance $\Sigma = 8 \cdot 1_{2\times2}$. The ground truth map is produced by summing these. When we make use of the perspective maps provided, we divide $\Sigma$ by the perspective map value at that pixel x, represented by $M(x)$. The provided perspective map for UCSD contains both a horizontal and vertical direction so we take the square root of the provided combined value. For training, we take 1600 random 79 × 119 pixel patches and for testing, we split each test image up into quadrants (which have dimension 79 × 119). There are two different ways to split the dataset into training and testing sets. We have experimented on the split that gave [\[18\]](#page-8-4) the best results as well as the split used in [\[28\]](#page-9-0).
First, we split the dataset into four separate groups of training and testing sets as used in [\[18\]](#page-8-4) and originally defined by [\[20\]](#page-8-0). These groups are "upscale," "maximal," "minimal," and "downscale." We see in Table [3](#page-6-0) that the "upscale" split and "downscale" split give us state of the art results on counting for this dataset. For this experiment, we sampled 1600 random patches of size 119 × 79 pixels (width and height respectively) for the training set and split the test set images into 119 × 79 quadrants that could be reconstructed by piecing them together without overlap. We also added left-right flips of each image to our training data.
We then evaluate the original split. For this experiment, we similarly sampled 1600 random patches of size 119 × 79 pixels (width and height respectively) for the training set and split the test set images into 119 × 79 quadrants that could be reconstructed by piecing them together without overlap.
#### 3.2.4 WorldExpo '10 Crowd Counting
The WorldExpo dataset [\[27\]](#page-8-7) contains a larger number of people (approximately 50 on average, which is double that of UCSD) and contains images from multiple locations. Perspective effects are also much more noticeable in this dataset as compared to UCSD. These qualities of the dataset serve to increase the difficulty of counting. Like UCSD, the WorldExpo dataset was constructed from frames of video recordings of crowds. This means that, unlike UCF, this dataset contains a relatively large number of training and testing images. We experiment on this dataset with and without perspective information.
Without perspective maps, we generate label density maps for this dataset in the same manner as previously described: a 2D Gaussian with $\sigma = 15$. We take 16000 150 × 150 randomly sampled patches for training. For testing, we densely scan the image, producing 150 × 150 patches at a stride of 100.
When perspective maps are used, however, we follow the procedure as described in [\[27\]](#page-8-7), which involves estimating a “crowd density distribution kernel” as the sum of two 2D Gaussians: a symmetric Gaussian for the head and an ellipsoid Gaussian for the body. These are scaled by the perspective map $M$ provided, where $M(x)$ gives the number of pixels that represents a meter at pixel $x$ [\[27\]](#page-8-7). Note that the meaning of this perspective map is distinct from the meaning of the perspective map provided for the UCSD dataset. Using this information, the density contribution from a person with head pixel $x$ is given by the following sum of normalized Gaussians:
$$D_{\mathbf{x}} = \frac{1}{||Z||} (\mathcal{N}_h(\mathbf{x}, \sigma_h) + \mathcal{N}_b(\mathbf{x}_b, \Sigma_b)) \qquad (5)$$
where $x_b$ is the center of the body, which is 0.875 meters down from the head on average, and can be determined from the perspective map $M$ and the head center $x$ [\[27\]](#page-8-7). We sum these Gaussians for each person to pro-
| Method | MAE |
|--------------|--------|
| AMDCN | 290.82 |
| Hydra2s [18] | 333.73 |
| MCNN [28] | 377.60 |
| [27] | 467.00 |
| [23] | 295.80 |
| [3] | 318.10 |
<span id="page-5-1"></span>Table 1. Mean absolute error of various methods on UCF crowds
duce the final density map. We set $\sigma = 0.2M(\mathbf{x})$ for $N_h $and $\sigma_x = 0.2M(\mathbf{x}), \sigma_y = 0.5M(\mathbf{x})$ for $\Sigma_b$ in $N_b$.
# 4. Results
#### 4.1. UCF Crowd Counting
The UCF dataset is particularly challenging due to the large number of people in the images, the variety of the scenes, as well as the low number of training images. We see in Figure [2](#page-2-2) that because the UCF dataset has over 1000 people on average in each image, the shapes output by the network in the density map are not as well defined or separated as in the UCSD dataset.
We report a state of the art result on this dataset in Table [1,](#page-5-1) following the standard protocol of 5-fold cross validation. Our MAE on the dataset is 290.82, which is approximately 5 lower than the previous state of the art, HydraCNN [\[18\]](#page-8-4). This is particularly indicative of the power of an aggregated multicolumn dilation network. Despite not making use of perspective information, the AMDCN is still able to produce highly accurate density maps for UCF.
#### 4.2. TRANCOS Traffic Counting
Our network performs very well on the TRANCOS dataset. Indeed, as confirmed by the GAME score, AMDCN produces the most accurate count and shape combined as compared to other methods. Table 2 shows that we achieve state of the art results as measured by the GAME metric [\[14\]](#page-8-1) across all levels.
#### 4.3. UCSD Crowd Counting
Results are shown in Table [3](#page-6-0) and Figure [3.](#page-6-1) We see that the "original" split as defined by the creators of the dataset in [\[5\]](#page-8-17) and used in [\[28\]](#page-9-0) gives us somewhat worse results for counting on this dataset. Results were consistent over multiple trainings. Again, including the perspective map does not seem to increase performance on this dataset. Despite this, we see in Table [3](#page-6-0) and Figure [3](#page-6-1) that the results are comparable to the state of the art. In fact, for two of the splits, our proposed network beats the state of the art. For the upscale split, the AMDCN is the state of the art by a large relative margin. This is compelling because it shows that accurate perspective-free counting can be achieved without
| Method | GAME<br>(L=0) | GAME<br>(L=1) | GAME<br>(L=2) | GAME<br>(L=3) |
|-------------------------------------------|---------------|---------------|---------------|---------------|
| AMDCN | <b>9.77</b> | <b>13.16</b> | <b>15.00</b> | <b>15.87</b> |
| [18] | 10.99 | 13.75 | 16.69 | 19.32 |
| [15] + SIFT<br>from [14] | 13.76 | 16.72 | 20.72 | 24.36 |
| [13] + RGB<br>Norm + Filters<br>from [14] | 17.68 | 19.97 | 23.54 | 25.84 |
| HOG-2<br>from [14] | 13.29 | 18.05 | 23.65 | 28.41 |
<span id="page-5-2"></span>Table 2. Mean absolute error of various methods on TRANCOS traffic
creating image pyramids or requiring perspective maps as labels using the techniques presented by the AMDCN.
#### 4.4. WorldExpo '10 Crowd Counting
Our network performs reasonably well on the more challenging WorldExpo dataset. While it does not beat the state of the art, our results are comparable. What is more, we do not need to use the perspective maps to obtain these results. As seen in Table [4,](#page-7-1) the AMDCN is capable of incorporating the perspective effects without scaling the Gaussians with perspective information. This shows that it is possible to achieve counting results that approach the state of the art with much simpler labels for the counting training data.
#### <span id="page-5-0"></span>4.5. Ablation Studies
We report the results of the ablation studies in Figure [4.](#page-7-2) We note from these plots that while there is variation in performance, a few trends stand out. Most importantly, the lowest errors are consistently with a combination of a larger number of columns and including the aggregator module. Notably for the TRANCOS dataset, including the aggregator consistently improves performance. Generally, the aggregator tends to decrease the variance in performance of the network. Some of the variance that we see in the plots can be explained by: (1) for lower numbers of columns, including an aggregator is not as likely to help as there is not much separation of multiscale information across columns and (2) for the UCSD dataset, there is less of a perspective effect than TRANCOS and WorldExpo so a simpler network is more likely to perform comparably to a larger network. These results verify the notion that using more columns increases accuracy, and also support our justification for the use of the aggregator module.
![](_page_6_Figure_0.jpeg)
(a) UCSD upscale split. (b) UCSD original split.
<span id="page-6-1"></span>Figure 3. UCSD crowd counting dataset. Both plots show comparisons of predicted and ground truth counts over time. While AMDCN does not beat the state of the art on the original split, the predictions still follow the true counts reasonably. The jump in the original split is due to that testing set including multiple scenes of highly varying counts.
| Method | maximal | downscale | upscale | minimal | original |
|-----------------------------------------|---------|-----------|---------|---------|----------|
| AMDCN (without perspective information) | 1.63 | 1.43 | 0.63 | 1.71 | 1.74 |
| AMDCN (with perspective information) | 1.60 | 1.24 | 1.37 | 1.59 | 1.72 |
| [18] (with perspective information) | 1.65 | 1.79 | 1.11 | 1.50 | - |
| [18] (without perspective information) | 2.22 | 1.93 | 1.37 | 2.38 | - |
| [15] | 1.70 | 1.28 | 1.59 | 2.02 | - |
| [13] | 1.70 | 2.16 | 1.61 | 2.20 | - |
| [19] | 1.43 | 1.30 | 1.59 | 1.62 | - |
| [2] | 1.24 | 1.31 | 1.69 | 1.49 | - |
| [27] | 1.70 | 1.26 | 1.59 | 1.52 | 1.60 |
| [28] | - | - | - | - | 1.07 |
| [1, 28] | - | - | - | - | 2.16 |
| [7] | - | - | - | - | 2.25 |
| [5] | - | - | - | - | 2.24 |
| [6] | - | - | - | - | 2.07 |
<span id="page-6-0"></span>Table 3. Mean absolute error of various methods on UCSD crowds
# 5. Conclusion
#### 5.1. Summary
We have proposed the use of aggregated multicolumn dilated convolutions, the AMDCN, as an alternative to the HydraCNN [\[18\]](#page-8-4) or multicolumn CNN [\[28\]](#page-9-0) for the vision task of counting objects in images. Inspired by the multicolumn approach to multiscale problems, we also employ dilations to increase the receptive field of our columns. We then aggregate this multiscale information using another series of dilated convolutions to enable a wide network and detect features at more scales. This method takes advantage of the ability of dilated convolutions to provide exponentially increasing receptive fields. We have performed experiments on the challenging UCF crowd counting dataset, the TRANCOS traffic dataset, multiple splits of the UCSD crowd counting dataset, and the WorldExpo crowd counting dataset.
![](_page_7_Figure_0.jpeg)
<span id="page-7-2"></span>Figure 4. Ablation studies on various datasets in which the number of columns is varied and the aggregator is included or not included. The results generally support the use of more columns and an aggregator module.
| Method | MAE |
|--------------------------------------------|-------------|
| AMDCN (without perspective information) | 16.6 |
| AMDCN (with perspective information) | 14.9 |
| LBP+RR [28] (with perspective information) | 31.0 |
| MCNN [28] (with perspective information) | <b>11.6</b> |
| [27] (with perspective information) | 12.9 |
<span id="page-7-1"></span>Table 4. Mean absolute error of various methods on WorldExpo crowds
We obtain superior or comparable results in most of these datasets. The AMDCN is capable of outperforming these approaches completely especially when perspective information is not provided, as in UCF and TRANCOS. These results show that the AMDCN performs surprisingly well and is also robust to scale effects. Further, our ablation study of removing the aggregator network shows that using more columns and an aggregator provides the best accuracy for counting — especially so when there is no perspective information.
#### 5.2. Future Work
In addition to an analysis of performance on counting, a density regressor can also be used to locate objects in the image. As mentioned previously, if the regressor is accurate and precise enough, the resulting density map can be used to locate the objects in the image. We expect that in order to do this, one must regress each object to a single point rather than a region specified by a Gaussian. Perhaps this might be accomplished by applying non-maxima suppression to the final layer activations.
Indeed, the method of applying dilated filters to a multicolumn convolutional network in order to enable extracting features of a large number of scales can be applied to various other dense prediction tasks, such as object segmentation at multiple scales or single image depth map prediction. Though we have only conducted experiments on counting and used 5 columns, the architecture presented can be extended and adapted to a variety of tasks that require information at multiple scales.
# Acknowledgment
This material is based upon work supported by the National Science Foundation under Grant No. 1359275 and 1659788. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the authors and do not necessarily reflect the views of the National Science Foundation. Furthermore, we acknowledge Kyle Yee and Sridhama Prakhya for their helpful conversations and insights during the research process.
# References
- <span id="page-7-4"></span>[1] S. An, W. Liu, and S. Venkatesh. Face recognition using kernel ridge regression. In *Computer Vision and Pattern Recognition, 2007. CVPR'07. IEEE Conference on*, pages 17. IEEE, 2007.
- <span id="page-7-3"></span>[2] C. Arteta, V. Lempitsky, J. A. Noble, and A. Zisserman. Interactive object counting. In *European Conference on Computer Vision*, pages 504518. Springer, 2014.
- <span id="page-7-0"></span>[3] D. Babu Sam, S. Surya, and R. Venkatesh Babu. Switching convolutional neural network for crowd
counting. In *Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition*, pages 57445752, 2017.
- <span id="page-8-16"></span>[4] L. Boominathan, S. S. Kruthiventi, and R. V. Babu. Crowdnet: A deep convolutional network for dense crowd counting. In *Proceedings of the 2016 ACM on Multimedia Conference*, pages 640644. ACM, 2016.
- <span id="page-8-17"></span>[5] A. B. Chan, Z.-S. J. Liang, and N. Vasconcelos. Privacy preserving crowd monitoring: Counting people without people models or tracking. In *Computer Vision and Pattern Recognition, 2008. CVPR 2008. IEEE Conference on*, pages 17. IEEE, 2008.
- <span id="page-8-23"></span>[6] K. Chen, S. Gong, T. Xiang, and C. Change Loy. Cumulative attribute space for age and crowd density estimation. In *Proceedings of the IEEE conference on computer vision and pattern recognition*, pages 2467 2474, 2013.
- <span id="page-8-22"></span>[7] K. Chen, C. C. Loy, S. Gong, and T. Xiang. Feature mining for localised crowd counting.
- <span id="page-8-13"></span>[8] L.-C. Chen, G. Papandreou, I. Kokkinos, K. Murphy, and A. L. Yuille. Deeplab: Semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected crfs. *IEEE Transactions on Pattern Analysis and Machine Intelligence*, 2017.
- <span id="page-8-10"></span>[9] L.-C. Chen, Y. Yang, J. Wang, W. Xu, and A. L. Yuille. Attention to scale: Scale-aware semantic image segmentation. In *Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition*, pages 36403649, 2016.
- <span id="page-8-18"></span>[10] F. Chollet et al. Keras. [https://github.com/](https://github.com/fchollet/keras) [fchollet/keras](https://github.com/fchollet/keras), 2015.
- <span id="page-8-8"></span>[11] A. Dosovitskiy, P. Fischer, E. Ilg, P. Hausser, C. Hazirbas, V. Golkov, P. van der Smagt, D. Cremers, and T. Brox. Flownet: Learning optical flow with convolutional networks. In *Proceedings of the IEEE International Conference on Computer Vision*, pages 2758 2766, 2015.
- <span id="page-8-11"></span>[12] C. Farabet, C. Couprie, L. Najman, and Y. Le-Cun. Learning hierarchical features for scene labeling. *IEEE transactions on pattern analysis and machine intelligence*, 35(8):19151929, 2013.
- <span id="page-8-20"></span>[13] L. Fiaschi, U. Kothe, R. Nair, and F. A. Hamprecht. ¨ Learning to count with regression forest and structured labels. In *Pattern Recognition (ICPR), 2012 21st International Conference on*, pages 26852688. IEEE, 2012.
- <span id="page-8-1"></span>[14] R. Guerrero-Gomez-Olmedo, B. Torre-Jim ´ enez, S. M. ´ Lopez-Sastre, Roberto Basc ´ on, and D. O ´ noro Rubio. ˜ Extremely overlapping vehicle counting. In *Iberian Conference on Pattern Recognition and Image Analysis (IbPRIA)*, 2015.
- <span id="page-8-3"></span>[15] V. Lempitsky and A. Zisserman. Learning to count objects in images. In *Advances in Neural Information Processing Systems*, pages 13241332, 2010.
- <span id="page-8-12"></span>[16] G. Lin, C. Shen, A. van den Hengel, and I. Reid. Efficient piecewise training of deep structured models for semantic segmentation. In *Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition*, pages 31943203, 2016.
- <span id="page-8-9"></span>[17] H. Noh, S. Hong, and B. Han. Learning deconvolution network for semantic segmentation. In *Proceedings of the IEEE International Conference on Computer Vision*, pages 15201528, 2015.
- <span id="page-8-4"></span>[18] D. Onoro-Rubio and R. J. Lopez-Sastre. Towards ´ perspective-free object counting with deep learning. In *European Conference on Computer Vision*, pages 615629. Springer, 2016.
- <span id="page-8-21"></span>[19] V.-Q. Pham, T. Kozakaya, O. Yamaguchi, and R. Okada. Count forest: Co-voting uncertain number of targets using random forest for crowd density estimation. In *Proceedings of the IEEE International Conference on Computer Vision*, pages 32533261, 2015.
- <span id="page-8-0"></span>[20] D. Ryan, S. Denman, C. Fookes, and S. Sridharan. Crowd counting using multiple local features. In *Digital Image Computing: Techniques and Applications, 2009. DICTA'09.*, pages 8188. IEEE, 2009.
- <span id="page-8-14"></span>[21] S. Segu´ı, O. Pujol, and J. Vitria. Learning to count with deep object features. In *Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops*, pages 9096, 2015.
- <span id="page-8-2"></span>[22] J. Selinummi, O. Yli-Harja, and J. A. Puhakka. Software for quantification of labeled bacteria from digital microscope images by automated image analysis. *Biotechniques*, 39(6):859, 2005.
- <span id="page-8-19"></span>[23] V. A. Sindagi and V. M. Patel. Generating high-quality crowd density maps using contextual pyramid cnns. In *Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition*, pages 18611870, 2017.
- <span id="page-8-15"></span>[24] E. Walach and L. Wolf. Learning to count with cnn boosting. In *European Conference on Computer Vision*, pages 660676. Springer, 2016.
- <span id="page-8-5"></span>[25] F. Yu and V. Koltun. Multi-scale context aggregation by dilated convolutions. *arXiv preprint arXiv:1511.07122*, 2015.
- <span id="page-8-6"></span>[26] F. Yu, V. Koltun, and T. Funkhouser. Dilated residual networks. *arXiv preprint arXiv:1705.09914*, 2017.
- <span id="page-8-7"></span>[27] C. Zhang, H. Li, X. Wang, and X. Yang. Crossscene crowd counting via deep convolutional neural networks. In *Proceedings of the IEEE Conference on*
*Computer Vision and Pattern Recognition*, pages 833 841, 2015.
- <span id="page-9-0"></span>[\[28\]](#page-8-8) Y. Zhang, D. Zhou, S. Chen, S. Gao, and Y. Ma. Single-image crowd counting via multi-column convolutional neural network. In *Proceedings of the IEEE Conference on Computer Vision and Pattern Recogni tion*, pages 589597, 2016.

@ -0,0 +1,955 @@
{
"table_of_contents": [
{
"title": "An Aggregated Multicolumn Dilated Convolution Network\nfor Perspective-Free Counting",
"heading_level": null,
"page_id": 0,
"polygon": [
[
117.5888671875,
105.9219970703125
],
[
477.371826171875,
105.9219970703125
],
[
477.371826171875,
138.201171875
],
[
117.5888671875,
138.201171875
]
]
},
{
"title": "Abstract",
"heading_level": null,
"page_id": 0,
"polygon": [
[
144.1845703125,
232.4891357421875
],
[
190.48028564453125,
232.4891357421875
],
[
190.48028564453125,
244.4443359375
],
[
144.1845703125,
244.4443359375
]
]
},
{
"title": "1. Introduction",
"heading_level": null,
"page_id": 0,
"polygon": [
[
50.016357421875,
512.06591796875
],
[
128.49609375,
512.06591796875
],
[
128.49609375,
524.0211181640625
],
[
50.016357421875,
524.0211181640625
]
]
},
{
"title": "2. Related Work",
"heading_level": null,
"page_id": 0,
"polygon": [
[
307.1953125,
621.7747497558594
],
[
392.0625,
621.7747497558594
],
[
392.0625,
633.7299499511719
],
[
307.1953125,
633.7299499511719
]
]
},
{
"title": "3. Method",
"heading_level": null,
"page_id": 2,
"polygon": [
[
49.4560546875,
371.27313232421875
],
[
101.91387939453125,
371.27313232421875
],
[
101.91387939453125,
383.22833251953125
],
[
49.4560546875,
383.22833251953125
]
]
},
{
"title": "3.1. Dilated Convolutions for Multicolumn Net-\nworks",
"heading_level": null,
"page_id": 2,
"polygon": [
[
49.53076171875,
391.4488220214844
],
[
287.173828125,
391.4488220214844
],
[
287.173828125,
414.3627014160156
],
[
49.53076171875,
414.3627014160156
]
]
},
{
"title": "3.2. Experiments",
"heading_level": null,
"page_id": 3,
"polygon": [
[
49.119873046875,
263.935546875
],
[
128.95028686523438,
263.935546875
],
[
128.95028686523438,
274.936767578125
],
[
49.119873046875,
274.936767578125
]
]
},
{
"title": "3.2.1 UCF50 Crowd Counting",
"heading_level": null,
"page_id": 3,
"polygon": [
[
307.79296875,
339.732421875
],
[
443.4609375,
339.732421875
],
[
443.4609375,
350.13201904296875
],
[
307.79296875,
350.13201904296875
]
]
},
{
"title": "3.2.2 TRANCOS Traffic Counting",
"heading_level": null,
"page_id": 3,
"polygon": [
[
308.689453125,
624.1640625
],
[
461.689453125,
624.1640625
],
[
461.689453125,
634.7828826904297
],
[
308.689453125,
634.7828826904297
]
]
},
{
"title": "3.2.3 UCSD Crowd Counting",
"heading_level": null,
"page_id": 4,
"polygon": [
[
49.38134765625,
314.06341552734375
],
[
182.28515625,
314.06341552734375
],
[
182.28515625,
324.0260009765625
],
[
49.38134765625,
324.0260009765625
]
]
},
{
"title": "3.2.4 WorldExpo '10 Crowd Counting",
"heading_level": null,
"page_id": 4,
"polygon": [
[
308.86199951171875,
259.17828369140625
],
[
477.4889221191406,
259.17828369140625
],
[
477.4889221191406,
269.140869140625
],
[
308.86199951171875,
269.140869140625
]
]
},
{
"title": "4. Results",
"heading_level": null,
"page_id": 5,
"polygon": [
[
49.343994140625,
231.4151611328125
],
[
100.5556640625,
231.4151611328125
],
[
100.5556640625,
243.370361328125
],
[
49.343994140625,
243.370361328125
]
]
},
{
"title": "4.1. UCF Crowd Counting",
"heading_level": null,
"page_id": 5,
"polygon": [
[
49.418701171875,
251.10882568359375
],
[
173.4697265625,
251.10882568359375
],
[
173.4697265625,
262.0677490234375
],
[
49.418701171875,
262.0677490234375
]
]
},
{
"title": "4.2. TRANCOS Traffic Counting",
"heading_level": null,
"page_id": 5,
"polygon": [
[
49.68017578125,
455.92767333984375
],
[
203.80078125,
455.92767333984375
],
[
203.80078125,
466.8865661621094
],
[
49.68017578125,
466.8865661621094
]
]
},
{
"title": "4.3. UCSD Crowd Counting",
"heading_level": null,
"page_id": 5,
"polygon": [
[
49.941650390625,
553.1486358642578
],
[
181.08984375,
553.1486358642578
],
[
181.08984375,
564.1075286865234
],
[
49.941650390625,
564.1075286865234
]
]
},
{
"title": "4.4. WorldExpo '10 Crowd Counting",
"heading_level": null,
"page_id": 5,
"polygon": [
[
308.689453125,
318.3517761230469
],
[
480.814453125,
318.3517761230469
],
[
480.814453125,
329.3106689453125
],
[
308.689453125,
329.3106689453125
]
]
},
{
"title": "4.5. Ablation Studies",
"heading_level": null,
"page_id": 5,
"polygon": [
[
308.689453125,
475.50469970703125
],
[
405.6838684082031,
475.50469970703125
],
[
405.6838684082031,
486.4635925292969
],
[
308.689453125,
486.4635925292969
]
]
},
{
"title": "5. Conclusion",
"heading_level": null,
"page_id": 6,
"polygon": [
[
48.48486328125,
594.6561584472656
],
[
119.20110321044922,
594.6561584472656
],
[
119.20110321044922,
607.1484375
],
[
48.48486328125,
607.1484375
]
]
},
{
"title": "5.1. Summary",
"heading_level": null,
"page_id": 6,
"polygon": [
[
49.194580078125,
619.6148376464844
],
[
115.55853271484375,
619.6148376464844
],
[
115.55853271484375,
630.73828125
],
[
49.194580078125,
630.73828125
]
]
},
{
"title": "5.2. Future Work",
"heading_level": null,
"page_id": 7,
"polygon": [
[
49.269287109375,
611.3048095703125
],
[
130.67086791992188,
611.3048095703125
],
[
130.67086791992188,
622.2637023925781
],
[
49.269287109375,
622.2637023925781
]
]
},
{
"title": "Acknowledgment",
"heading_level": null,
"page_id": 7,
"polygon": [
[
308.86199951171875,
446.23602294921875
],
[
398.337890625,
446.23602294921875
],
[
398.337890625,
458.19122314453125
],
[
308.86199951171875,
458.19122314453125
]
]
},
{
"title": "References",
"heading_level": null,
"page_id": 7,
"polygon": [
[
308.86199951171875,
571.0409851074219
],
[
365.16796875,
571.0409851074219
],
[
365.16796875,
582.9961853027344
],
[
308.86199951171875,
582.9961853027344
]
]
}
],
"page_stats": [
{
"page_id": 0,
"text_extraction_method": "pdftext",
"block_counts": [
[
"Span",
176
],
[
"Line",
84
],
[
"Text",
10
],
[
"SectionHeader",
4
],
[
"PageHeader",
1
],
[
"PageFooter",
1
]
],
"block_metadata": {
"llm_request_count": 0,
"llm_error_count": 0,
"llm_tokens_used": 0
}
},
{
"page_id": 1,
"text_extraction_method": "pdftext",
"block_counts": [
[
"Span",
201
],
[
"Line",
74
],
[
"Text",
5
],
[
"Figure",
1
],
[
"Caption",
1
],
[
"FigureGroup",
1
],
[
"Reference",
1
]
],
"block_metadata": {
"llm_request_count": 0,
"llm_error_count": 0,
"llm_tokens_used": 0
}
},
{
"page_id": 2,
"text_extraction_method": "pdftext",
"block_counts": [
[
"Span",
327
],
[
"Line",
96
],
[
"Text",
10
],
[
"Reference",
3
],
[
"SectionHeader",
2
],
[
"Equation",
2
],
[
"Picture",
1
],
[
"Caption",
1
],
[
"TextInlineMath",
1
],
[
"Footnote",
1
],
[
"PictureGroup",
1
]
],
"block_metadata": {
"llm_request_count": 2,
"llm_error_count": 0,
"llm_tokens_used": 4608
}
},
{
"page_id": 3,
"text_extraction_method": "pdftext",
"block_counts": [
[
"Span",
337
],
[
"Line",
109
],
[
"Text",
8
],
[
"SectionHeader",
3
],
[
"Equation",
1
],
[
"TextInlineMath",
1
],
[
"Reference",
1
]
],
"block_metadata": {
"llm_request_count": 1,
"llm_error_count": 0,
"llm_tokens_used": 3057
}
},
{
"page_id": 4,
"text_extraction_method": "pdftext",
"block_counts": [
[
"Span",
505
],
[
"Line",
121
],
[
"Text",
6
],
[
"TextInlineMath",
6
],
[
"Equation",
2
],
[
"SectionHeader",
2
],
[
"Reference",
1
]
],
"block_metadata": {
"llm_request_count": 2,
"llm_error_count": 0,
"llm_tokens_used": 3814
}
},
{
"page_id": 5,
"text_extraction_method": "pdftext",
"block_counts": [
[
"Span",
332
],
[
"TableCell",
113
],
[
"Line",
100
],
[
"Text",
7
],
[
"SectionHeader",
6
],
[
"Reference",
3
],
[
"Table",
2
],
[
"Caption",
2
],
[
"TableGroup",
2
],
[
"TextInlineMath",
1
]
],
"block_metadata": {
"llm_request_count": 3,
"llm_error_count": 0,
"llm_tokens_used": 7669
}
},
{
"page_id": 6,
"text_extraction_method": "pdftext",
"block_counts": [
[
"Span",
229
],
[
"TableCell",
180
],
[
"Line",
37
],
[
"Caption",
4
],
[
"SectionHeader",
2
],
[
"Text",
2
],
[
"Reference",
2
],
[
"Figure",
1
],
[
"Table",
1
],
[
"FigureGroup",
1
],
[
"TableGroup",
1
]
],
"block_metadata": {
"llm_request_count": 2,
"llm_error_count": 0,
"llm_tokens_used": 7459
}
},
{
"page_id": 7,
"text_extraction_method": "pdftext",
"block_counts": [
[
"Span",
145
],
[
"Line",
68
],
[
"TableCell",
32
],
[
"Text",
5
],
[
"Reference",
5
],
[
"SectionHeader",
3
],
[
"ListItem",
3
],
[
"Caption",
2
],
[
"Figure",
1
],
[
"Table",
1
],
[
"FigureGroup",
1
],
[
"TableGroup",
1
],
[
"ListGroup",
1
]
],
"block_metadata": {
"llm_request_count": 1,
"llm_error_count": 0,
"llm_tokens_used": 2613
}
},
{
"page_id": 8,
"text_extraction_method": "pdftext",
"block_counts": [
[
"Span",
312
],
[
"Line",
101
],
[
"ListItem",
24
],
[
"Reference",
24
],
[
"ListGroup",
2
],
[
"Text",
1
]
],
"block_metadata": {
"llm_request_count": 0,
"llm_error_count": 0,
"llm_tokens_used": 0
}
},
{
"page_id": 9,
"text_extraction_method": "pdftext",
"block_counts": [
[
"Span",
26
],
[
"Line",
7
],
[
"Text",
1
],
[
"ListItem",
1
],
[
"Reference",
1
]
],
"block_metadata": {
"llm_request_count": 0,
"llm_error_count": 0,
"llm_tokens_used": 0
}
}
],
"debug_data_path": "debug_data/multicolcnn"
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 64 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 58 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 31 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 56 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 59 KiB

@ -0,0 +1,775 @@
# Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity
### William Fedus
liamfedus@google.com
### Barret Zoph
barretzoph@google.com
#### Noam Shazeer
noam@google.com Google, Mountain View, CA 94043, USA
Editor: Alexander Clark
### Abstract
In deep learning, models typically reuse the same parameters for all inputs. Mixture of Experts (MoE) models defy this and instead select different parameters for each incoming example. The result is a sparsely-activated model—with an outrageous number of parameters—but a constant computational cost. However, despite several notable successes of MoE, widespread adoption has been hindered by complexity, communication costs, and training instability. We address these with the introduction of the Switch Transformer. We simplify the MoE routing algorithm and design intuitive improved models with reduced communication and computational costs. Our proposed training techniques mitigate the instabilities, and we show large sparse models may be trained, for the first time, with lower precision (bfloat16) formats. We design models based off T5-Base and T5-Large (Raffel et al., 2019) to obtain up to 7x increases in pre-training speed with the same computational resources. These improvements extend into multilingual settings where we measure gains over the mT5-Base version across all 101 languages. Finally, we advance the current scale of language models by pre-training up to trillion parameter models on the “Colossal Clean Crawled Corpus”, and achieve a 4x speedup over the T5-XXL model.[1](#page-0-0)[2](#page-0-1)
Keywords: mixture-of-experts, natural language processing, sparsity, large-scale machine learning, distributed computing
©2022 William Fedus, Barret Zoph and Noam Shazeer.
<sup></sup>. Equal contribution.
<span id="page-0-0"></span><sup>1.</sup> JAX code for Switch Transformer and all model checkpoints are available at [https://github.com/](https://github.com/google-research/t5x) [google-research/t5x](https://github.com/google-research/t5x)
<span id="page-0-1"></span><sup>2.</sup> Tensorflow code for Switch Transformer is available at [https://github.com/tensorflow/mesh/blob/](https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py) [master/mesh_tensorflow/transformer/moe.py](https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py)
## Contents
| 1 | Introduction | 3 | |
|---|-----------------------------------------------------------|----|----|
| 2 | Switch Transformer | 4 | |
| | 2.1 Simplifying Sparse Routing | 5 | |
| | 2.2 Efficient Sparse Routing | 6 | |
| | 2.3 Putting It All Together: The Switch Transformer | 8 | |
| | 2.4 Improved Training and Fine-Tuning Techniques | 8 | |
| 3 | Scaling Properties | 11 | |
| | 3.1 Scaling Results on a Step-Basis | 12 | |
| | 3.2 Scaling Results on a Time-Basis | 13 | |
| | 3.3 Scaling Versus a Larger Dense Model | 13 | |
| 4 | Downstream Results | 14 | |
| | 4.1 Fine-Tuning | 14 | |
| | 4.2 Distillation | 16 | |
| | 4.3 Multilingual Learning | 17 | |
| 5 | Designing Models with Data, Model, and Expert-Parallelism | 18 | |
| | 5.1 Data Parallelism | 20 | |
| | 5.2 Model Parallelism | 20 | |
| | 5.3 Model and Data Parallelism | 21 | |
| | 5.4 Expert and Data Parallelism | 22 | |
| | 5.5 Expert, Model and Data Parallelism | 22 | |
| | 5.6 Towards Trillion Parameter Models | 22 | |
| 6 | Related Work | 24 | |
| 7 | Discussion | 25 | |
| 8 | Future Work | 26 | |
| 9 | Conclusion | 27 | |
| A | Switch for Attention | 27 | |
| B | Preventing Token Dropping with No-Token-Left-Behind | 29 | |
| C | Encouraging Exploration Across Experts | 29 | |
| D | Switch Transformers in Lower Compute Regimes | | 29 |
| E | Relation of Upstream to Downstream Model Performance | | 32 |
| F | Pseudo Code for Switch Transformers | 33 | |
### <span id="page-2-0"></span>1. Introduction
Large scale training has been an effective path towards flexible and powerful neural language models [(Radford et al.,](#page-37-1) [2018;](#page-37-1) [Kaplan et al.,](#page-36-0) [2020;](#page-36-0) [Brown et al.,](#page-35-0) [2020)](#page-35-0). Simple architectures backed by a generous computational budget, data set size and parameter count—surpass more complicated algorithms [(Sutton,](#page-38-0) [2019)](#page-38-0). An approach followed in [Radford et al.](#page-37-1) [(2018)](#page-37-1); [Raffel et al.](#page-37-0) [(2019)](#page-37-0); [Brown et al.](#page-35-0) [(2020)](#page-35-0) expands the model size of a densely-activated Transformer [(Vaswani et al.,](#page-39-1) [2017)](#page-39-1). While effective, it is also extremely computationally intensive [(Strubell et al.,](#page-38-1) [2019)](#page-38-1). Inspired by the success of model scale, but seeking greater computational efficiency, we instead propose a sparsely-activated expert model: the Switch Transformer. In our case the sparsity comes from activating a subset of the neural network weights for each incoming example.
![](_page_2_Figure_3.jpeg)
Figure 1: Scaling and sample efficiency of Switch Transformers. Left Plot: Scaling properties for increasingly sparse (more experts) Switch Transformers. Right Plot: Negative log perplexity comparing Switch Transformers to T5 [(Raffel et al.,](#page-37-0) [2019)](#page-37-0) models using the same compute budget.
Sparse training is an active area of research and engineering [(Gray et al.,](#page-35-1) [2017;](#page-35-1) [Gale](#page-35-2) [et al.,](#page-35-2) [2020)](#page-35-2), but as of today, machine learning libraries and hardware accelerators still cater to dense matrix multiplications. To have an efficient sparse algorithm, we start with the Mixture-of-Expert (MoE) paradigm [(Jacobs et al.,](#page-36-1) [1991;](#page-36-1) [Jordan and Jacobs,](#page-36-2) [1994;](#page-36-2) [Shazeer](#page-38-2) [et al.,](#page-38-2) [2017)](#page-38-2), and simplify it to yield training stability and computational benefits. MoE models have had notable successes in machine translation [(Shazeer et al.,](#page-38-2) [2017,](#page-38-2) [2018;](#page-38-3) [Lep](#page-37-2)[ikhin et al.,](#page-37-2) [2020)](#page-37-2), however, widespread adoption is hindered by complexity, communication costs, and training instabilities.
We address these issues, and then go beyond translation, to find that these class of algorithms are broadly valuable in natural language. We measure superior scaling on a diverse set of natural language tasks and across three regimes in NLP: pre-training, finetuning and multi-task training. While this work focuses on scale, we also show that the Switch Transformer architecture not only excels in the domain of supercomputers, but is beneficial even with only a few computational cores. Further, our large sparse models can be distilled [(Hinton et al.,](#page-36-3) [2015)](#page-36-3) into small dense versions while preserving 30% of the sparse model quality gain. Our contributions are the following:
- The Switch Transformer architecture, which simplifies and improves over Mixture of Experts.
- Scaling properties and a benchmark against the strongly tuned T5 model [(Raffel et al.,](#page-37-0) [2019)](#page-37-0) where we measure 7x+ pre-training speedups while still using the same FLOPS per token. We further show the improvements hold even with limited computational resources, using as few as two experts.
- Successful distillation of sparse pre-trained and specialized fine-tuned models into small dense models. We reduce the model size by up to 99% while preserving 30% of the quality gains of the large sparse teacher.
- Improved pre-training and fine-tuning techniques: (1) selective precision training that enables training with lower bfloat16 precision (2) an initialization scheme that allows for scaling to a larger number of experts and (3) increased expert regularization that improves sparse model fine-tuning and multi-task training.
- A measurement of the pre-training benefits on multilingual data where we find a universal improvement across all 101 languages and with 91% of languages benefiting from 4x+ speedups over the mT5 baseline [(Xue et al.,](#page-39-2) [2020)](#page-39-2).
- An increase in the scale of neural language models achieved by efficiently combining data, model, and expert-parallelism to create models with up to a trillion parameters. These models improve the pre-training speed of a strongly tuned T5-XXL baseline by 4x.
## <span id="page-3-0"></span>2. Switch Transformer
The guiding design principle for Switch Transformers is to maximize the parameter count of a Transformer model [(Vaswani et al.,](#page-39-1) [2017)](#page-39-1) in a simple and computationally efficient way. The benefit of scale was exhaustively studied in [Kaplan et al.](#page-36-0) [(2020)](#page-36-0) which uncovered powerlaw scaling with model size, data set size and computational budget. Importantly, this work advocates training large models on relatively small amounts of data as the computationally optimal approach.
Heeding these results, we investigate a fourth axis: increase the parameter count while keeping the floating point operations (FLOPs) per example constant. Our hypothesis is that the parameter count, independent of total computation performed, is a separately important axis on which to scale. We achieve this by designing a sparsely activated model that efficiently uses hardware designed for dense matrix multiplications such as GPUs and TPUs. Our work here focuses on TPU architectures, but these class of models may be similarly trained on GPU clusters. In our distributed training setup, our sparsely activated layers split unique weights on different devices. Therefore, the weights of the model increase with the number of devices, all while maintaining a manageable memory and computational footprint on each device.
![](_page_4_Figure_1.jpeg)
Figure 2: Illustration of a Switch Transformer encoder block. We replace the dense feed forward network (FFN) layer present in the Transformer with a sparse Switch FFN layer (light blue). The layer operates independently on the tokens in the sequence. We diagram two tokens ($x_1$ = “More” and $x_2$ = “Parameters” below) being routed (solid lines) across four FFN experts, where the router independently routes each token. The switch FFN layer returns the output of the selected FFN multiplied by the router gate value (dotted-line).
### <span id="page-4-0"></span>2.1 Simplifying Sparse Routing
Mixture of Expert Routing. [Shazeer et al.](#page-38-2) [(2017)](#page-38-2) proposed a natural language Mixtureof-Experts (MoE) layer which takes as an input a token representation $x$ and then routes this to the best determined top-$k$ experts, selected from a set ${E_i(x)\}_{i=1}^N$ of $N$ experts. The router variable $W_r$ produces logits $h(x) = W_r \cdot x$ which are normalized via a softmax distribution over the available $N$ experts at that layer. The gate-value for expert $i$ is given by,
$$p_i(x) = \frac{e^{h(x)_i}}{\sum_j^N e^{h(x)_j}}.\tag{1}$$
The top-$k$ gate values are selected for routing the token $x$. If $\mathcal{T}$ is the set of selected top-$k$ indices then the output computation of the layer is the linearly weighted combination of each expert's computation on the token by the gate value,
<span id="page-4-1"></span>
$$y = \sum_{i \in \mathcal{T}} p_i(x) E_i(x). \tag{2}$$
**Switch Routing:** Rethinking Mixture-of-Experts. [Shazeer et al.](#page-38-2) (2017) conjectured that routing to $k > 1$ experts was necessary in order to have non-trivial gradients to the routing functions. The authors intuited that learning to route would not work without the ability to compare at least two experts. [Ramachandran and Le](#page-37-3) (2018) went further to study the top-$k$ decision and found that higher $k$-values in lower layers in the model were important for models with many routing layers. Contrary to these ideas, we instead use a simplified strategy where we route to only a single expert. We show this simplification preserves model quality, reduces routing computation and performs better. This $k = 1 $routing strategy is later referred to as a Switch layer. Note that for both MoE and Switch Routing, the gate value $p_i(x)$ in Equation [2](#page-4-1) permits differentiability of the router.
The benefits for the Switch layer are three-fold: (1) The router computation is reduced as we are only routing a token to a single expert. (2) The batch size (expert capacity) of each expert can be at least halved since each token is only being routed to a single expert.[3](#page-5-1) (3) The routing implementation is simplified and communication costs are reduced. Figure [3](#page-5-2) shows an example of routing with different expert capacity factors.
![](_page_5_Figure_3.jpeg)
- <span id="page-5-2"></span>Figure 3: Illustration of token routing dynamics. Each expert processes a fixed batch-size of tokens modulated by the capacity factor. Each token is routed to the expert with the highest router probability, but each expert has a fixed batch size of (total_tokens / num_experts) × capacity_factor. If the tokens are unevenly dispatched then certain experts will overflow (denoted by dotted red lines), resulting in these tokens not being processed by this layer. A larger capacity factor alleviates this overflow issue, but also increases computation and communication costs (depicted by padded white/empty slots).
### <span id="page-5-0"></span>2.2 Efficient Sparse Routing
We use Mesh-Tensorflow (MTF) [(Shazeer et al.,](#page-38-3) [2018)](#page-38-3) which is a library, with similar semantics and API to Tensorflow [(Abadi et al.,](#page-35-3) [2016)](#page-35-3) that facilitates efficient distributed data and model parallel architectures. It does so by abstracting the physical set of cores to a logical mesh of processors. Tensors and computations may then be sharded per named dimensions, facilitating easy partitioning of models across dimensions. We design our model with TPUs in mind, which require statically declared sizes. Below we describe our distributed Switch Transformer implementation.
<span id="page-5-1"></span><sup>3.</sup> See Section [2.2](#page-5-0) for a technical description.
Distributed Switch Implementation. All of our tensor shapes are statically determined at compilation time, but our computation is dynamic due to the routing decisions at training and inference. Because of this, one important technical consideration is how to set the expert capacity. The expert capacity—the number of tokens each expert computes—is set by evenly dividing the number of tokens in the batch across the number of experts, and then further expanding by a capacity factor,
$$\text{expected capacity} = \left(\frac{\text{tokens per batch}}{\text{number of experts}}\right) \times \text{capacity factor}.\tag{3}$$
A capacity factor greater than 1.0 creates additional buffer to accommodate for when tokens are not perfectly balanced across experts. If too many tokens are routed to an expert (referred to later as dropped tokens), computation is skipped and the token representation is passed directly to the next layer through the residual connection. Increasing the expert capacity is not without drawbacks, however, since high values will result in wasted computation and memory. This trade-off is explained in Figure [3.](#page-5-2) Empirically we find ensuring lower rates of dropped tokens are important for the scaling of sparse expert-models. Throughout our experiments we didn't notice any dependency on the number of experts for the number of tokens dropped (typically $< 1%$). Using the auxiliary load balancing loss (next section) with a high enough coefficient ensured good load balancing. We study the impact that these design decisions have on model quality and speed in Table [1.](#page-8-0)
A Differentiable Load Balancing Loss. To encourage a balanced load across experts we add an auxiliary loss (Shazeer et al., 2017, 2018; Lepikhin et al., 2020). As in Shazeer et al. (2018); Lepikhin et al. (2020), Switch Transformers simplifies the original design in Shazeer et al. (2017) which had separate load-balancing and importance-weighting losses. For each Switch layer, this auxiliary loss is added to the total model loss during training. Given $N$ experts indexed by $i = 1$ to $N$ and a batch $\mathcal{B}$ with $T$ tokens, the auxiliary loss is computed as the scaled dot-product between vectors $f$ and $P$,
$$\text{loss} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i \tag{4}$$
<span id="page-6-1"></span>where $f_i$ is the fraction of tokens dispatched to expert $i$,
$$f_i = \frac{1}{T} \sum_{x \in \mathcal{B}} \mathbb{1} \{ \operatorname*{argmax} \, p(x) = i \} \tag{5}$$
and $P_i$ is the fraction of the router probability allocated for expert $i$,2
$$P_i = \frac{1}{T} \sum_{x \in \mathcal{B}} p_i(x). \tag{6}$$
Since we seek uniform routing of the batch of tokens across the $N$ experts, we desire both vectors to have values of $1/N$. The auxiliary loss of Equation [4](#page-6-1) encourages uniform routing since it is minimized under a uniform distribution. The objective can also be differentiated as
<span id="page-6-0"></span>A potential source of confusion: $p_i(x)$ is the probability of routing token $x$ to expert $i$. $P_i$ is the probability fraction to expert $i$ across all tokens in the batch $B$.
the $P$-vector is differentiable, but the $f$-vector is not. The final loss is multiplied by expert count $N$ to keep the loss constant as the number of experts varies since under uniform routing $\sum_{i=1}^{N}(f_i$ · Pi) = PN i=1( 1 N · 1 N ) = 1 N . Finally, a hyper-parameter α is a multiplicative coefficient for these auxiliary losses; throughout this work we use an α = 102 which was sufficiently large to ensure load balancing while small enough to not to overwhelm the primary cross-entropy objective. We swept hyper-parameter ranges of $\alpha$ from $10^{-1}$ to $10^{-5}$ in powers of 10 and found $10^{-2}$ balanced load quickly without interfering with training loss.
#### <span id="page-7-0"></span>2.3 Putting It All Together: The Switch Transformer
Our first test of the Switch Transformer starts with pre-training on the "Colossal Clean Crawled Corpus" (C4), introduced in [(Raffel et al.,](#page-37-0) [2019)](#page-37-0). For our pre-training objective, we use a masked language modeling task [(Taylor,](#page-38-4) [1953;](#page-38-4) [Fedus et al.,](#page-35-4) [2018;](#page-35-4) [Devlin et al.,](#page-35-5) [2018)](#page-35-5) where the model is trained to predict missing tokens. In our pre-training setting, as determined in [Raffel et al.](#page-37-0) [(2019)](#page-37-0) to be optimal, we drop out 15% of tokens and then replace the masked sequence with a single sentinel token. To compare our models, we record the negative log perplexity.[4](#page-7-2) Throughout all tables in the paper, ↑ indicates that a higher value for that metric is better and vice-versa for ↓. A comparison of all the models studied in this work are in Table [9.](#page-22-0)
A head-to-head comparison of the Switch Transformer and the MoE Transformer is presented in Table [1.](#page-8-0) Our Switch Transformer model is FLOP-matched to 'T5-Base' [(Raffel](#page-37-0) [et al.,](#page-37-0) [2019)](#page-37-0) (same amount of computation per token is applied). The MoE Transformer, using top-2 routing, has two experts which each apply a separate FFN to each token and thus its FLOPS are larger. All models were trained for the same number of steps on identical hardware. Note that the MoE model going from capacity factor 2.0 to 1.25 actually slows down (840 to 790) in the above experiment setup, which is unexpected.[5](#page-7-3)
We highlight three key findings from Table [1:](#page-8-0) (1) Switch Transformers outperform both carefully tuned dense models and MoE Transformers on a speed-quality basis. For a fixed amount of computation and wall-clock time, Switch Transformers achieve the best result. (2) The Switch Transformer has a smaller computational footprint than the MoE counterpart. If we increase its size to match the training speed of the MoE Transformer, we find this outperforms all MoE and Dense models on a per step basis as well. (3) Switch Transformers perform better at lower capacity factors (1.0, 1.25). Smaller expert capacities are indicative of the scenario in the large model regime where model memory is very scarce and the capacity factor will want to be made as small as possible.
#### <span id="page-7-1"></span>2.4 Improved Training and Fine-Tuning Techniques
Sparse expert models may introduce training difficulties over a vanilla Transformer. Instability can result because of the hard-switching (routing) decisions at each of these layers. Further, low precision formats like bfloat16 [(Wang and Kanwar,](#page-39-3) [2019)](#page-39-3) can exacerbate issues
<span id="page-7-2"></span><sup>4.</sup> We use log base-e for this metric so the units are nats.
<span id="page-7-3"></span><sup>5.</sup> Note that speed measurements are both a function of the algorithm and the implementation details. Switch Transformer reduces the necessary computation relative to MoE (algorithm), but the final speed differences are impacted by low-level optimizations (implementation).
| Model | Capacity<br>Factor | Quality after<br>100k steps (↑)<br>(Neg. Log Perp.) | Time to Quality<br>Threshold (↓)<br>(hours) | Speed (↑)<br>(examples/sec) |
|--------------|--------------------|-----------------------------------------------------|---------------------------------------------|-----------------------------|
| T5-Base | — | -1.731 | Not achieved† | 1600 |
| T5-Large | — | -1.550 | 131.1 | 470 |
| MoE-Base | 2.0 | -1.547 | 68.7 | 840 |
| Switch-Base | 2.0 | -1.554 | 72.8 | 860 |
| MoE-Base | 1.25 | -1.559 | 80.7 | 790 |
| Switch-Base | 1.25 | -1.553 | 65.0 | 910 |
| MoE-Base | 1.0 | -1.572 | 80.1 | 860 |
| Switch-Base | 1.0 | -1.561 | <b>62.8</b> | 1000 |
| Switch-Base+ | 1.0 | <b>-1.534</b> | 67.6 | 780 |
- <span id="page-8-0"></span>Table 1: Benchmarking Switch versus MoE. Head-to-head comparison measuring per step and per time benefits of the Switch Transformer over the MoE Transformer and T5 dense baselines. We measure quality by the negative log perplexity and the time to reach an arbitrary chosen quality threshold of Neg. Log Perp.=-1.50. All MoE and Switch Transformer models use 128 experts, with experts at every other feed-forward layer. For Switch-Base+, we increase the model size until it matches the speed of the MoE model by increasing the model hidden-size from 768 to 896 and the number of heads from 14 to 16. All models are trained with the same amount of computation (32 cores) and on the same hardware (TPUv3). Further note that all our models required pre-training beyond 100k steps to achieve our level threshold of -1.50. † T5-Base did not achieve this negative log perplexity in the 100k steps the models were trained.
in the softmax computation for our router. We describe training difficulties here and the methods we use to overcome them to achieve stable and scalable training.
Selective precision with large sparse models. Model instability hinders the ability to train using efficient bfloat16 precision, and as a result, [Lepikhin et al.](#page-37-2) [(2020)](#page-37-2) trains with float32 precision throughout their MoE Transformer. However, we show that by instead selectively casting to float32 precision within a localized part of the model, stability may be achieved, without incurring expensive communication cost of float32 tensors. This technique is inline with modern mixed precision training strategies where certain parts of the model and gradient updates are done in higher precision [Micikevicius et al.](#page-37-4) [(2017)](#page-37-4). Table [2](#page-9-0) shows that our approach permits nearly equal speed to bfloat16 training while conferring the training stability of float32.
To achieve this, we cast the router input to float32 precision. The router function takes the tokens as input and produces the dispatch and combine tensors used for the selection and recombination of expert computation (refer to Code Block [15](#page-33-0) in the Appendix for details). Importantly, the float32 precision is only used within the body of the router function—on computations local to that device. Because the resulting dispatch and combine tensors are recast to bfloat16 precision at the end of the function, no expensive float32 tensors
| Model<br>(precision) | Quality<br>(Neg. Log Perp.) (↑) | Speed<br>(Examples/sec) (↑) |
|-----------------------------------|---------------------------------|-----------------------------|
| Switch-Base (float32) | -1.718 | 1160 |
| Switch-Base (bfloat16) | -3.780 [diverged] | <b>1390</b> |
| Switch-Base (Selective precision) | <b>-1.716</b> | 1390 |
<span id="page-9-0"></span>Table 2: Selective precision. We cast the local routing operations to float32 while preserving bfloat16 precision elsewhere to stabilize our model while achieving nearly equal speed to (unstable) bfloat16-precision training. We measure the quality of a 32 expert model after a fixed step count early in training its speed performance. For both Switch-Base in float32 and with Selective prevision we notice similar learning dynamics.
are broadcast through all-to-all communication operations, but we still benefit from the increased stability of float32.
Smaller parameter initialization for stability. Appropriate initialization is critical to successful training in deep learning and we especially observe this to be true for Switch Transformer. We initialize our weight matrices by drawing elements from a truncated normal distribution with mean $\mu = 0$ and standard deviation $\sigma = \sqrt{s/n}$ where $s$ is a scale hyper-parameter and $n$ is the number of input units in the weight tensor (e.g. fan-in).[6](#page-9-1)
As an additional remedy to the instability, we recommend reducing the default Transformer initialization scale $s = 1.0$ by a factor of 10. This both improves quality and reduces the likelihood of destabilized training in our experiments. Table [3](#page-9-2) measures the improvement of the model quality and reduction of the variance early in training. We find that
| Model (Initialization scale) | Average Quality<br>(Neg. Log Perp.) | Std. Dev. of Quality<br>(Neg. Log Perp.) |
|------------------------------|-------------------------------------|------------------------------------------|
| Switch-Base (0.1x-init) | <b>-2.72</b> | <b>0.01</b> |
| Switch-Base (1.0x-init) | -3.60 | 0.68 |
- <span id="page-9-2"></span>Table 3: Reduced initialization scale improves stability. Reducing the initialization scale results in better model quality and more stable training of Switch Transformer. Here we record the average and standard deviation of model quality, measured by the negative log perplexity, of a 32 expert model after 3.5k steps (3 random seeds each).
the average model quality, as measured by the Neg. Log Perp., is dramatically improved and there is a far reduced variance across runs. Further, this same initialization scheme is broadly effective for models spanning several orders of magnitude. We use the same approach to stably train models as small as our 223M parameter baseline to enormous models in excess of one trillion parameters.
<span id="page-9-1"></span><sup>6.</sup> Values greater than two standard deviations from the mean are resampled.
Regularizing large sparse models. Our paper considers the common NLP approach of pre-training on a large corpus followed by fine-tuning on smaller downstream tasks such as summarization or question answering. One issue that naturally arises is overfitting since many fine-tuning tasks have very few examples. During fine-tuning of standard Transformers, [Raffel et al.](#page-37-0) [(2019)](#page-37-0) use dropout [(Srivastava et al.,](#page-38-5) [2014)](#page-38-5) at each layer to prevent overfitting. Our Switch Transformers have significantly more parameters than the FLOP matched dense baseline, which can lead to more severe overfitting on these smaller downstream tasks.
| Model (dropout) | GLUE | CNNDM | SQuAD | SuperGLUE |
|-----------------------------|-------------|-------------|-------------|-------------|
| T5-Base (d=0.1) | 82.9 | <b>19.6</b> | 83.5 | 72.4 |
| Switch-Base (d=0.1) | 84.7 | 19.1 | <b>83.7</b> | <b>73.0</b> |
| Switch-Base (d=0.2) | 84.4 | 19.2 | <b>83.9</b> | <b>73.2</b> |
| Switch-Base (d=0.3) | 83.9 | 19.6 | 83.4 | 70.7 |
| Switch-Base (d=0.1, ed=0.4) | <b>85.2</b> | <b>19.6</b> | <b>83.7</b> | <b>73.0</b> |
<span id="page-10-1"></span>Table 4: Fine-tuning regularization results. A sweep of dropout rates while fine-tuning Switch Transformer models pre-trained on 34B tokens of the C4 data set (higher numbers are better). We observe that using a lower standard dropout rate at all non-expert layer, with a much larger dropout rate on the expert feed-forward layers, to perform the best.
We thus propose a simple way to alleviate this issue during fine-tuning: increase the dropout inside the experts, which we name as expert dropout. During fine-tuning we simply increase the dropout rate by a significant amount only at the interim feed-forward computation at each expert layer. Table [4](#page-10-1) has the results for our expert dropout protocol. We observe that simply increasing the dropout across all layers leads to worse performance. However, setting a smaller dropout rate (0.1) at non-expert layers and a much larger dropout rate (0.4) at expert layers leads to performance improvements on four smaller downstream tasks.
### <span id="page-10-0"></span>3. Scaling Properties
We present a study of the scaling properties of the Switch Transformer architecture during pre-training. Per [Kaplan et al.](#page-36-0) [(2020)](#page-36-0), we consider a regime where the model is not bottlenecked by either the computational budget or amount of data. To avoid the data bottleneck, we use the large C4 corpus with over 180B target tokens [(Raffel et al.,](#page-37-0) [2019)](#page-37-0) and we train until diminishing returns are observed.
The number of experts is the most efficient dimension for scaling our model. Increasing the experts keeps the computational cost approximately fixed since the model only selects one expert per token, regardless of the number of experts to choose from. The router must compute a probability distribution over more experts, however, this is a lightweight computation of cost $O(d_{model} \times \text{num experts})$ where $d_{model}$ is the embedding dimension of tokens passed between the layers. In this section, we consider the scaling properties on a step-basis and a time-basis with a fixed computational budget.
#### <span id="page-11-0"></span>3.1 Scaling Results on a Step-Basis
Figure [4](#page-11-1) demonstrates consistent scaling benefits with the number of experts when training all models for a fixed number of steps. We observe a clear trend: when keeping the FLOPS per token fixed, having more parameters (experts) speeds up training. The left Figure demonstrates consistent scaling properties (with fixed FLOPS per token) between sparse model parameters and test loss. This reveals the advantage of scaling along this additional axis of sparse model parameters. Our right Figure measures sample efficiency of a dense model variant and four FLOP-matched sparse variants. We find that increasing the number of experts leads to more sample efficient models. Our Switch-Base 64 expert model achieves the same performance of the T5-Base model at step 60k at step 450k, which is a 7.5x speedup in terms of step time. In addition, consistent with the findings of [Kaplan et al.](#page-36-0) [(2020)](#page-36-0), we find that larger models are also more sample efficient—learning more quickly for a fixed number of observed tokens.
![](_page_11_Figure_4.jpeg)
- <span id="page-11-1"></span>Figure 4: Scaling properties of the Switch Transformer. Left Plot: We measure the quality improvement, as measured by perplexity, as the parameters increase by scaling the number of experts. The top-left point corresponds to the T5-Base model with 223M parameters. Moving from top-left to bottom-right, we double the number of experts from 2, 4, 8 and so on until the bottom-right point of a 256 expert model with 14.7B parameters. Despite all models using an equal computational budget, we observe consistent improvements scaling the number of experts. Right Plot: Negative log perplexity per step sweeping over the number of experts. The dense baseline is shown with the purple line and we note improved sample efficiency of our Switch-Base models.
#### <span id="page-12-0"></span>3.2 Scaling Results on a Time-Basis
Figure [4](#page-11-1) demonstrates that on a step basis, as we increase the number of experts, the performance consistently improves. While our models have roughly the same amount of FLOPS per token as the baseline, our Switch Transformers incurs additional communication costs across devices as well as the extra computation of the routing mechanism. Therefore, the increased sample efficiency observed on a step-basis doesn't necessarily translate to a better model quality as measured by wall-clock. This raises the question:
For a fixed training duration and computational budget, should one train a dense or a sparse model?
![](_page_12_Figure_4.jpeg)
<span id="page-12-2"></span>Figure 5: Speed advantage of Switch Transformer. All models trained on 32 TPUv3 cores with equal FLOPs per example. For a fixed amount of computation and training time, Switch Transformers significantly outperform the dense Transformer baseline. Our 64 expert Switch-Base model achieves the same quality in one-seventh the time of the T5-Base and continues to improve.
Figures [5](#page-12-2) and [6](#page-13-2) address this question. Figure [5](#page-12-2) measures the pre-training model quality as a function of time. For a fixed training duration and computational budget, Switch Transformers yield a substantial speed-up. In this setting, our Switch-Base 64 expert model trains in one-seventh the time that it would take the T5-Base to get similar perplexity.
#### <span id="page-12-1"></span>3.3 Scaling Versus a Larger Dense Model
The above analysis shows that a computationally-matched dense model is outpaced by its Switch counterpart. Figure [6](#page-13-2) considers a different scenario: what if we instead had allocated our resources to a larger dense model? We do so now, measuring Switch-Base against the next strong baseline, T5-Large. But despite T5-Large applying 3.5x more FLOPs per token, Switch-Base is still more sample efficient and yields a 2.5x speedup. Furthermore, more gains can be had simply by designing a new, larger sparse version, Switch-Large, which is FLOP-matched to T5-Large. We do this and demonstrate superior scaling and fine-tuning in the following section.
![](_page_13_Figure_2.jpeg)
<span id="page-13-2"></span>Figure 6: Scaling Transformer models with Switch layers or with standard dense model scaling. Left Plot: Switch-Base is more sample efficient than both the T5-Base, and T5-Large variant, which applies 3.5x more FLOPS per token. Right Plot: As before, on a wall-clock basis, we find that Switch-Base is still faster, and yields a 2.5x speedup over T5-Large.
### <span id="page-13-0"></span>4. Downstream Results
Section [3](#page-10-0) demonstrated the superior scaling properties while pre-training, but we now validate that these gains translate to improved language learning abilities on downstream tasks. We begin by fine-tuning on a diverse set of NLP tasks. Next we study reducing the memory footprint of our sparse models by over 90% by distilling into small—and easily deployed—dense baselines. Finally, we conclude this section measuring the improvements in a multi-task, multilingual setting, where we show that Switch Transformers are strong multi-task learners, improving over the multilingual T5-base model across all 101 languages.
### <span id="page-13-1"></span>4.1 Fine-Tuning
Baseline and Switch models used for fine-tuning. Our baselines are the highly-tuned 223M parameter T5-Base model and the 739M parameter T5-Large model [(Raffel et al.,](#page-37-0) [2019)](#page-37-0). For both versions, we design a FLOP-matched Switch Transformer, with many more parameters, which is summarized in Table 9.[7](#page-13-3) Our baselines differ slightly from those in [Raffel et al.](#page-37-0) [(2019)](#page-37-0) because we pre-train on an improved C4 corpus which removes intraexample text duplication and thus increases the efficacy as a pre-training task [Lee et al.](#page-37-5)
<span id="page-13-3"></span><sup>7.</sup> FLOPS are calculated for the forward pass as done in [Kaplan et al.](#page-36-0) [(2020)](#page-36-0).
[(2021)](#page-37-5). In our protocol we pre-train with $2^{20}$ (1,048,576) tokens per batch for 550k steps amounting to 576B total tokens. We then fine-tune across a diverse set of tasks using a dropout rate of 0.1 for all layers except the Switch layers, which use a dropout rate of 0.4 (see Table [4](#page-10-1)). We fine-tune using a batch-size of 1M for 16k steps and for each task, we evaluate model quality every 200-steps and report the peak performance as computed on the validation set.
Fine-tuning tasks and data sets. We select tasks probing language capabilities including question answering, summarization and knowledge about the world. The language benchmarks GLUE [(Wang et al.,](#page-39-4) [2018)](#page-39-4) and SuperGLUE [(Wang et al.,](#page-39-5) [2019)](#page-39-5) are handled as composite mixtures with all the tasks blended in proportion to the amount of tokens present in each. These benchmarks consist of tasks requiring sentiment analysis (SST-2), word sense disambiguation (WIC), sentence similarty (MRPC, STS-B, QQP), natural language inference (MNLI, QNLI, RTE, CB), question answering (MultiRC, RECORD, BoolQ), coreference resolution (WNLI, WSC) and sentence completion (COPA) and sentence acceptability (CoLA). The CNNDM [(Hermann et al.,](#page-36-4) [2015)](#page-36-4) and BBC XSum [(Narayan](#page-37-6) [et al.,](#page-37-6) [2018)](#page-37-6) data sets are used to measure the ability to summarize articles. Question answering is probed with the SQuAD data set [(Rajpurkar et al.,](#page-37-7) [2016)](#page-37-7) and the ARC Reasoning Challenge [(Clark et al.,](#page-35-6) [2018)](#page-35-6). And as in [Roberts et al.](#page-38-6) [(2020)](#page-38-6), we evaluate the knowledge of our models by fine-tuning on three closed-book question answering data sets: Natural Questions [(Kwiatkowski et al.,](#page-36-5) [2019)](#page-36-5), Web Questions [(Berant et al.,](#page-35-7) [2013)](#page-35-7) and Trivia QA [(Joshi et al.,](#page-36-6) [2017)](#page-36-6). Closed-book refers to questions posed with no supplemental reference or context material. To gauge the model's common sense reasoning we evaluate it on the Winogrande Schema Challenge [(Sakaguchi et al.,](#page-38-7) [2020)](#page-38-7). And finally, we test our model's natural language inference capabilities on the Adversarial NLI Benchmark [(Nie et al.,](#page-37-8) [2019)](#page-37-8).
Fine-tuning metrics. The following evaluation metrics are used throughout the paper: We report the average scores across all subtasks for GLUE and SuperGLUE. The Rouge-2 metric is used both the CNNDM and XSum. In SQuAD and the closed book tasks (Web, Natural, and Trivia Questions) we report the percentage of answers exactly matching the target (refer to [Roberts et al.](#page-38-6) [(2020)](#page-38-6) for further details and deficiency of this measure). Finally, in ARC Easy, ARC Challenge, ANLI, and Winogrande we report the accuracy of the generated responses.
Fine-tuning results. We observe significant downstream improvements across many natural language tasks. Notable improvements come from SuperGLUE, where we find FLOP-matched Switch variants improve by 4.4 and 2 percentage points over the T5-Base and T5-Large baselines, respectively as well as large improvements in Winogrande, closed book Trivia QA, and XSum.[8](#page-14-0) In our fine-tuning study, the only tasks where we do not observe gains are on the AI2 Reasoning Challenge (ARC) data sets where the T5-Base outperforms Switch-Base on the challenge data set and T5-Large outperforms Switch-Large on the easy data set. Taken as a whole, we observe significant improvements spanning both reasoning and knowledge-heavy tasks. This validates our architecture, not just as one that pre-trains well, but can translate quality improvements to downstream tasks via fine-tuning.
<span id="page-14-0"></span>8. Our T5 and Switch models were pre-trained with $2^{20}$ tokens per batch for 550k steps on a revised C4 data set for fair comparisons.
| Model | GLUE | SQuAD | SuperGLUE | Winogrande (XL) |
|--------------|-------------|---------------|--------------|-----------------|
| T5-Base | 84.3 | 85.5 | 75.1 | 66.6 |
| Switch-Base | <b>86.7</b> | <b>87.2</b> | <b>79.5</b> | <b>73.3</b> |
| T5-Large | 87.8 | 88.1 | 82.7 | 79.1 |
| Switch-Large | <b>88.5</b> | <b>88.6</b> | <b>84.7</b> | <b>83.0</b> |
| Model | XSum | ANLI (R3) | ARC Easy | ARC Chal. |
| T5-Base | 18.7 | 51.8 | 56.7 | <b>35.5</b> |
| Switch-Base | <b>20.3</b> | <b>54.0</b> | <b>61.3</b> | 32.8 |
| T5-Large | 20.9 | 56.6 | <b>68.8</b> | <b>35.5</b> |
| Switch-Large | <b>22.3</b> | <b>58.6</b> | 66.0 | <b>35.5</b> |
| Model | CB Web QA | CB Natural QA | CB Trivia QA | |
| T5-Base | 26.6 | 25.8 | 24.5 | |
| Switch-Base | <b>27.4</b> | <b>26.8</b> | <b>30.7</b> | |
| T5-Large | 27.7 | 27.6 | 29.5 | |
| Switch-Large | <b>31.3</b> | <b>29.5</b> | <b>36.9</b> | |
Table 5: Fine-tuning results. Fine-tuning results of T5 baselines and Switch models across a diverse set of natural language tests (validation sets; higher numbers are better). We compare FLOP-matched Switch models to the T5-Base and T5-Large baselines. For most tasks considered, we find significant improvements of the Switchvariants. We observe gains across both model sizes and across both reasoning and knowledge-heavy language tasks.
#### <span id="page-15-0"></span>4.2 Distillation
Deploying massive neural networks with billions, or trillions, of parameters is inconvenient. To alleviate this, we study distilling [(Hinton et al.,](#page-36-3) [2015)](#page-36-3) large sparse models into small dense models. Future work could additionally study distilling large models into smaller sparse models.
Distillation techniques. In Table [6](#page-16-1) we study a variety of distillation techniques. These techniques are built off of [Sanh et al.](#page-38-8) [(2019)](#page-38-8), who study distillation methods for BERT models. We find that initializing the dense model with the non-expert weights yields a modest improvement. This is possible since all models are FLOP matched, so non-expert layers will have the same dimensions. Since expert layers are usually only added at every or every other FFN layer in a Transformer, this allows for many of the weights to be initialized with trained parameters. Furthermore, we observe a distillation improvement using a mixture of 0.25 for the teacher probabilities and 0.75 for the ground truth label. By combining both techniques we preserve $\approx$ 30% of the quality gains from the larger sparse models with only $\approx$ 1/20th of the parameters. The quality gain refers to the percent of
| Technique | Parameters | Quality (↑) |
|-------------------------------------------|------------|--------------|
| T5-Base | 223M | -1.636 |
| Switch-Base | 3,800M | -1.444 |
| Distillation | 223M | (3%) -1.631 |
| + Init. non-expert weights from teacher | 223M | (20%) -1.598 |
| + 0.75 mix of hard and soft loss | 223M | (29%) -1.580 |
| Initialization Baseline (no distillation) | | |
| Init. non-expert weights from teacher | 223M | -1.639 |
the quality difference between Switch-Base (Teacher) and T5-Base (Student). Therefore, a quality gain of 100% implies the Student equals the performance of the Teacher.
- <span id="page-16-1"></span>Table 6: Distilling Switch Transformers for Language Modeling. Initializing T5-Base with the non-expert weights from Switch-Base and using a loss from a mixture of teacher and ground-truth labels obtains the best performance. We can distill 30% of the performance improvement of a large sparse model with 100x more parameters back into a small dense model. For a final baseline, we find no improvement of T5-Base initialized with the expert weights, but trained normally without distillation.
Achievable compression rates. Using our best distillation technique described in Table [6,](#page-16-1) we distill a wide variety of sparse models into dense models. We distill Switch-Base versions, sweeping over an increasing number of experts, which corresponds to varying between 1.1B to 14.7B parameters. Through distillation, we can preserve 37% of the quality gain of the 1.1B parameter model while compressing 82%. At the extreme, where we compress the model 99%, we are still able to maintain 28% of the teacher's model quality improvement.
Distilling a fine-tuned model. We conclude this with a study of distilling a finetuned sparse model into a dense model. Table [8](#page-17-1) shows results of distilling a 7.4B parameter Switch-Base model, fine-tuned on the SuperGLUE task, into the 223M T5-Base. Similar to our pre-training results, we find we are able to preserve 30% of the gains of the sparse model when distilling into a FLOP matched dense variant. One potential future avenue, not considered here, may examine the specific experts being used for fine-tuning tasks and extracting them to achieve better model compression.
#### <span id="page-16-0"></span>4.3 Multilingual Learning
In our final set of downstream experiments, we measure the model quality and speed tradeoffs while pre-training on a mixture of 101 different languages. We build and benchmark off the recent work of mT5 [(Xue et al.,](#page-39-2) [2020)](#page-39-2), a multilingual extension to T5. We pre-train on the multilingual variant of the Common Crawl data set (mC4) spanning 101 languages introduced in mT5, but due to script variants within certain languages, the mixture contains 107 tasks.
In Figure [7](#page-18-0) we plot the quality improvement in negative log perplexity for all languages of a FLOP-matched Switch model, mSwitch-Base to the T5 base variant, mT5-Base. After
| | Dense | Sparse | | | | |
|--------------------------------|--------|--------|--------|--------|--------|--------|
| Parameters | 223M | 1.1B | 2.0B | 3.8B | 7.4B | 14.7B |
| Pre-trained Neg. Log Perp. (↑) | -1.636 | -1.505 | -1.474 | -1.444 | -1.432 | -1.427 |
| Distilled Neg. Log Perp. (↑) | — | -1.587 | -1.585 | -1.579 | -1.582 | -1.578 |
| Percent of Teacher Performance | — | 37% | 32% | 30 % | 27 % | 28 % |
| Compression Percent | — | 82 % | 90 % | 95 % | 97 % | 99 % |
- Table 7: Distillation compression rates. We measure the quality when distilling large sparse models into a dense baseline. Our baseline, T5-Base, has a -1.636 Neg. Log Perp. quality. In the right columns, we then distill increasingly large sparse models into this same architecture. Through a combination of weight-initialization and a mixture of hard and soft losses, we can shrink our sparse teachers by 95%+ while preserving 30% of the quality gain. However, for significantly better and larger pre-trained teachers, we expect larger student models would be necessary to achieve these compression rates.
| Model | Parameters | FLOPS | SuperGLUE (↑) |
|-------------------|------------|-------|---------------|
| T5-Base | 223M | 124B | 74.6 |
| Switch-Base | 7410M | 124B | 81.3 |
| Distilled T5-Base | 223M | 124B | (30%) 76.6 |
- <span id="page-17-1"></span>Table 8: Distilling a fine-tuned SuperGLUE model. We distill a Switch-Base model finetuned on the SuperGLUE tasks into a T5-Base model. We observe that on smaller data sets our large sparse model can be an effective teacher for distillation. We find that we again achieve 30% of the teacher's performance on a 97% compressed model.
pre-training both versions for 1M steps, we find that on all 101 languages considered, Switch Transformer increases the final negative log perplexity over the baseline. In Figure 8, we present a different view and now histogram the per step speed-up of using Switch Transformer over the mT5-Base.[9](#page-17-2) We find a mean speed-up over mT5-Base of 5x and that 91% of languages achieve at least a 4x speedup. This presents evidence that Switch Transformers are effective multi-task and multi-lingual learners.
### <span id="page-17-0"></span>5. Designing Models with Data, Model, and Expert-Parallelism
Arbitrarily increasing the number of experts is subject to diminishing returns (Figure [4)](#page-11-1). Here we describe complementary scaling strategies. The common way to scale a Transformer is to increase dimensions in tandem, like $d_{model}$ or $d_{ff}$. This increases both the parameters
<span id="page-17-2"></span><sup>9.</sup> The speedup on a step basis is computed as the ratio of the number of steps for the baseline divided by the number of steps required by our model to reach that same quality.
![](_page_18_Figure_1.jpeg)
<span id="page-18-0"></span>Figure 7: Multilingual pre-training on 101 languages. Improvements of Switch T5 Base model over dense baseline when multi-task training on 101 languages. We observe Switch Transformers to do quite well in the multi-task training setup and yield improvements on all 101 languages.
![](_page_18_Figure_3.jpeg)
<span id="page-18-1"></span>Figure 8: Multilingual pre-training on 101 languages. We histogram for each language, the step speedup of Switch Transformers over the FLOP matched T5 dense baseline to reach the same quality. Over all 101 languages, we achieve a mean step speedup over mT5-Base of 5x and, for 91% of languages, we record a 4x, or greater, speedup to reach the final perplexity of mT5-Base.
and computation performed and is ultimately limited by the memory per accelerator. Once it exceeds the size of the accelerator's memory, single program multiple data (SPMD) modelparallelism can be employed. This section studies the trade-offs of combining data, model, and expert-parallelism.
Reviewing the Feed-Forward Network (FFN) Layer. We use the FFN layer as an example of how data, model and expert-parallelism works in Mesh TensorFlow [(Shazeer](#page-38-3) [et al.,](#page-38-3) [2018)](#page-38-3) and review it briefly here. We assume $B$ tokens in the batch, each of dimension $d_{model}$. Both the input $(x)$ and output $(y)$ of the FFN are of size $[B, d_{model}]$ and the intermediate $(h)$ is of size $[B, d_{ff}]$ where $d_{ff}$ is typically several times larger than $d_{model}$. In the FFN, the intermediate is $h = xW_{in}$ and then the output of the layer is $y = ReLU(h)W_{out}$. Thus $W_{in}$ and $W_{out}$ are applied independently to each token and have sizes $[d_{model}, d_{ff}] $and $[d_{ff}, d_{model}]$.
We describe two aspects of partitioning: how the weights and batches of data divide over cores, depicted in Figure [9.](#page-20-1) We denote all cores available as $N$ which Mesh Tensorflow may then remap into a logical multidimensional mesh of processors. Here we create a two-dimensional logical mesh, with one dimension representing the number of ways for data-parallel sharding $(n)$ and the other, the model-parallel sharding $(m)$. The total cores must equal the ways to shard across both data and model-parallelism, e.g. $N = n \times m$. To shard the layer across cores, the tensors containing that batch of $B$ tokens are sharded across $n$ data-parallel cores, so each core contains $B/n$ tokens. Tensors and variables with $d_{ff}$ are then sharded across $m$ model-parallel cores. For the variants with experts-layers, we consider $E$ experts, each of which can process up to $C$ tokens.
| Term | Description |
|------|-------------------------------------------------|
| B | Number of tokens in the batch. |
| N | Number of total cores. |
| n | Number of ways for data-parallelism sharding. |
| m | Number of ways for model-parallelism sharding. |
| E | Number of experts in Switch layers. |
| C | Expert capacity, the batch size of each expert. |
#### <span id="page-19-0"></span>5.1 Data Parallelism
When training data parallel models, which is the standard for distributed training, then all cores are allocated to the data-parallel dimension or $n = N, m = 1$. This has the advantage that no communication is needed until the entire forward and backward pass is finished and the gradients need to be then aggregated across all cores. This corresponds to the left-most column of Figure [9.](#page-20-1)
#### <span id="page-19-1"></span>5.2 Model Parallelism
We now consider a scenario where all cores are allocated exclusively to the model-parallel dimension and so $n = 1, m = N$. Now all cores must keep the full $B$ tokens and each core will contain a unique slice of the weights. For each forward and backward pass, a communication cost is now incurred. Each core sends a tensor of $[B, d_{model}]$ to compute the second matrix multiplication $ReLU(h)W_{out}$ because the $d_{ff}$ dimension is partitioned and must be summed over. As a general rule, whenever a dimension that is partitioned across cores must be summed, then an all-reduce operation is added for both the forward and backward pass. This contrasts with pure data parallelism where an all-reduce only occurs at the end of the entire forward and backward pass.
![](_page_20_Figure_1.jpeg)
#### **How the** *model weights* **are split over cores**
#### **How the** *data* **is split over cores**
![](_page_20_Figure_4.jpeg)
- <span id="page-20-1"></span>Figure 9: Data and weight partitioning strategies. Each 4×4 dotted-line grid represents 16 cores and the shaded squares are the data contained on that core (either model weights or batch of tokens). We illustrate both how the model weights and the data tensors are split for each strategy. First Row: illustration of how model weights are split across the cores. Shapes of different sizes in this row represent larger weight matrices in the Feed Forward Network (FFN) layers (e.g larger $d_{ff}$ sizes). Each color of the shaded squares identifies a unique weight matrix. The number of parameters per core is fixed, but larger weight matrices will apply more computation to each token. Second Row: illustration of how the data batch is split across cores. Each core holds the same number of tokens which maintains a fixed memory usage across all strategies. The partitioning strategies have different properties of allowing each core to either have the same tokens or different tokens across cores, which is what the different colors symbolize.
#### <span id="page-20-0"></span>5.3 Model and Data Parallelism
It is common to mix both model and data parallelism for large scale models, which was done in the largest T5 models [(Raffel et al.,](#page-37-0) [2019;](#page-37-0) [Xue et al.,](#page-39-2) [2020)](#page-39-2) and in GPT-3 [(Brown et al.,](#page-35-0) [2020)](#page-35-0). With a total of $N = n \times m$ cores, now each core will be responsible for $B/n$ tokens and $d_{ff}/m$ of both the weights and intermediate activation. In the forward and backward pass each core communicates a tensor of size $[B/n, d_{model}]$ in an all-reduce operation.
#### <span id="page-21-0"></span>5.4 Expert and Data Parallelism
Next we describe the partitioning strategy for expert and data parallelism. Switch Transformers will allocate all of their cores to the data partitioning dimension $n$, which will also correspond to the number of experts in the model. For each token per core a router locally computes assignments to the experts. The output is a binary matrix of size $[n, B/n, E$, $C]$ which is partitioned across the first dimension and determines expert assignment. This binary matrix is then used to do a gather via matrix multiplication with the input tensor of $[n, B/n, d_{model}]$.
$$\text{einsum}([n, B/n, d_{model}], [n, B/n, E, C], \text{dimension} = [B/n]) \tag{7}$$
resulting in the final tensor of shape $[n, E, C, d_{model}]$, which is sharded across the first dimension. Because each core has its own expert, we do an all-to-all communication of size $[E, C, d_{model}]$ to now shard the $E$ dimension instead of the $n$-dimension. There are additional communication costs of bfloat16 tensors of size $E \times C \times d_{model}$ in the forward pass to analogously receive the tokens from each expert located on different cores. See Appendix [F](#page-32-0) for a detailed analysis of the expert partitioning code.
#### <span id="page-21-1"></span>5.5 Expert, Model and Data Parallelism
In the design of our best model, we seek to balance the FLOPS per token and the parameter count. When we scale the number of experts, we increase the number of parameters, but do not change the FLOPS per token. In order to increase FLOPS, we must also increase the $d_{ff}$ dimension (which also increases parameters, but at a slower rate). This presents a trade-off: as we increase $d_{ff}$ we will run out of memory per core, which then necessitates increasing $m$. But since we have a fixed number of cores $N$, and $N = n \times m$, we must decrease $n$, which forces use of a smaller batch-size (in order to hold tokens per core constant).
When combining both model and expert-parallelism, we will have all-to-all communication costs from routing the tokens to the correct experts along with the internal all-reduce communications from the model parallelism. Balancing the FLOPS, communication costs and memory per core becomes quite complex when combining all three methods where the best mapping is empirically determined. See our further analysis in section [5.6](#page-21-2) for how the number of experts effects the downstream performance as well.
#### <span id="page-21-2"></span>5.6 Towards Trillion Parameter Models
Combining expert, model and data parallelism, we design two large Switch Transformer models, one with 395 billion and 1.6 trillion parameters, respectively. We study how these models perform on both up-stream pre-training as language models and their downstream fine-tuning performance. The parameters, FLOPs per sequence and hyper-parameters of the two different models are listed below in Table 9. Standard hyper-parameters of the Transformer, including $d_{model}$, $d_{ff}$, $d_{kv}$, number of heads and number of layers are described, as well as a less common feature, $FFN_{GEGLU}$, which refers to a variation of the FFN layer where the expansion matrix is substituted with two sets of weights which are non-linearly combined (Shazeer, 2020).
The Switch-C model is designed using only expert-parallelism, and no model-parallelism, as described earlier in Section [5.4.](#page-21-0) As a result, the hyper-parameters controlling the width,
| Model | Parameters | FLOPs/seq | dmodel | F F NGEGLU | df f | dkv | Num. Heads |
|--------------|--------------|-------------|-------------|----------------------|-----------------------|-----|------------|
| T5-Base | 0.2B | 124B | 768 | X | 2048 | 64 | 12 |
| T5-Large | 0.7B | 425B | 1024 | X | 2816 | 64 | 16 |
| T5-XXL | 11B | 6.3T | 4096 | X | 10240 | 64 | 64 |
| Switch-Base | 7B | 124B | 768 | X | 2048 | 64 | 12 |
| Switch-Large | 26B | 425B | 1024 | X | 2816 | 64 | 16 |
| Switch-XXL | 395B | 6.3T | 4096 | X | 10240 | 64 | 64 |
| Switch-C | 1571B | 890B | 2080 | | 6144 | 64 | 32 |
| | | | | | | | |
| Model | Expert Freq. | Num. Layers | Num Experts | Neg. Log Perp. @250k | Neg. Log Perp. @ 500k | | |
| T5-Base | | 12 | | -1.599 | -1.556 | | |
| T5-Large | | 24 | | -1.402 | -1.350 | | |
| T5-XXL | | 24 | | -1.147 | -1.095 | | |
| Switch-Base | 1/2 | 12 | 128 | -1.370 | -1.306 | | |
| Switch-Large | 1/2 | 24 | 128 | -1.248 | -1.177 | | |
| Switch-XXL | 1/2 | 24 | 64 | -1.086 | -1.008 | | |
| Switch-C | 1 | 15 | 2048 | -1.096 | -1.043 | | |
- <span id="page-22-0"></span>Table 9: Switch model design and pre-training performance. We compare the hyperparameters and pre-training performance of the T5 models to our Switch Transformer variants. The last two columns record the pre-training model quality on the C4 data set after 250k and 500k steps, respectively. We observe that the Switch-C Transformer variant is 4x faster to a fixed perplexity (with the same compute budget) than the T5-XXL model, with the gap increasing as training progresses.
depth, number of heads, and so on, are all much smaller than the T5-XXL model. In contrast, the Switch-XXL is FLOP-matched to the T5-XXL model, which allows for larger dimensions of the hyper-parameters, but at the expense of additional communication costs induced by model-parallelism (see Section [5.5](#page-21-1) for more details).
Sample efficiency versus T5-XXL. In the final two columns of Table [9](#page-22-0) we record the negative log perplexity on the C4 corpus after 250k and 500k steps, respectively. After 250k steps, we find both Switch Transformer variants to improve over the T5-XXL version's negative log perplexity by over 0.061.[10](#page-22-1) To contextualize the significance of a gap of 0.061, we note that the T5-XXL model had to train for an additional 250k steps to increase 0.052. The gap continues to increase with additional training, with the Switch-XXL model out-performing the T5-XXL by 0.087 by 500k steps.
Training instability. However, as described in the introduction, large sparse models can be unstable, and as we increase the scale, we encounter some sporadic issues. We find that the larger Switch-C model, with 1.6T parameters and 2048 experts, exhibits no training instability at all. Instead, the Switch XXL version, with nearly 10x larger FLOPs per sequence, is sometimes unstable. As a result, though this is our better model on a step-basis, we do not pre-train for a full 1M steps, in-line with the final reported results of T5 [(Raffel et al.,](#page-37-0) [2019)](#page-37-0).
<span id="page-22-1"></span><sup>10.</sup> This reported quality difference is a lower bound, and may actually be larger. The T5-XXL was pretrained on an easier C4 data set which included duplicated, and thus easily copied, snippets within examples.
Reasoning fine-tuning performance. As a preliminary assessment of the model quality, we use a Switch-XXL model partially pre-trained on 503B tokens, or approximately half the text used by the T5-XXL model. Using this checkpoint, we conduct multi-task training for efficiency, where all tasks are learned jointly, rather than individually fine-tuned. We find that SQuAD accuracy on the validation set increases to 89.7 versus state-of-the-art of 91.3. Next, the average SuperGLUE test score is recorded at 87.5 versus the T5 version obtaining a score of 89.3 compared to the state-of-the-art of 90.0 [(Wang et al.,](#page-39-5) [2019)](#page-39-5). On ANLI [(Nie et al.,](#page-37-8) [2019)](#page-37-8), Switch XXL improves over the prior state-of-the-art to get a 65.7 accuracy versus the prior best of 49.4 [(Yang et al.,](#page-39-6) [2020)](#page-39-6). We note that while the Switch-XXL has state-of-the-art Neg. Log Perp. on the upstream pre-training task, its gains have not yet fully translated to SOTA downstream performance. We study this issue more in Appendix [E.](#page-31-0)
Knowledge-based fine-tuning performance. Finally, we also conduct an early examination of the model's knowledge with three closed-book knowledge-based tasks: Natural Questions, WebQuestions and TriviaQA, without additional pre-training using Salient Span Masking [(Guu et al.,](#page-36-7) [2020)](#page-36-7). In all three cases, we observe improvements over the prior stateof-the-art T5-XXL model (without SSM). Natural Questions exact match increases to 34.4 versus the prior best of 32.8, Web Questions increases to 41.0 over 37.2, and TriviaQA increases to 47.5 versus 42.9.
Summing up, despite training on less than half the data of other models, we already find comparable, and sometimes state-of-the-art, model quality. Currently, the Switch Transformer translates substantial upstream gains better to knowledge-based tasks, than reasoning-tasks (see Appendix [E)](#page-31-0). Extracting stronger fine-tuning performance from large expert models is an active research question, and the pre-training perplexity indicates future improvements should be possible.
### <span id="page-23-0"></span>6. Related Work
The importance of scale in neural networks is widely recognized and several approaches have been proposed. Recent works have scaled models to billions of parameters through using model parallelism (e.g. splitting weights and tensors across multiple cores) [(Shazeer et al.,](#page-38-3) [2018;](#page-38-3) [Rajbhandari et al.,](#page-37-9) [2019;](#page-37-9) [Raffel et al.,](#page-37-0) [2019;](#page-37-0) [Brown et al.,](#page-35-0) [2020;](#page-35-0) [Shoeybi et al.,](#page-38-10) [2019)](#page-38-10). Alternatively, [Harlap et al.](#page-36-8) [(2018)](#page-36-8); [Huang et al.](#page-36-9) [(2019)](#page-36-9) propose using pipeline based model parallelism, where different layers are split across devices and micro-batches are pipelined to the different layers. Finally, Product Key networks [(Lample et al.,](#page-37-10) [2019)](#page-37-10) were proposed to scale up the capacity of neural networks by doing a lookup for learnable embeddings based on the incoming token representations to a given layer.
Our work studies a specific model in a class of methods that do conditional computation, where computation decisions are made dynamically based on the input. [Cho and Bengio](#page-35-8) [(2014)](#page-35-8) proposed adaptively selecting weights based on certain bit patterns occuring in the model hidden-states. [Eigen et al.](#page-35-9) [(2013)](#page-35-9) built stacked expert layers with dense matrix multiplications and ReLU activations and showed promising results on jittered MNIST and monotone speech. In computer vision [Puigcerver et al.](#page-37-11) [(2020)](#page-37-11) manually route tokens based on semantic classes during upstream pre-training and then select the relevant experts to be used according to the downstream task.
Mixture of Experts (MoE), in the context of modern deep learning architectures, was proven effective in [Shazeer et al.](#page-38-2) [(2017)](#page-38-2). That work added an MoE layer which was stacked between LSTM [(Hochreiter and Schmidhuber,](#page-36-10) [1997)](#page-36-10) layers, and tokens were separately routed to combinations of experts. This resulted in state-of-the-art results in language modeling and machine translation benchmarks. The MoE layer was reintroduced into the Transformer architecture by the Mesh Tensorflow library [(Shazeer et al.,](#page-38-3) [2018)](#page-38-3) where MoE layers were introduced as a substitute of the FFN layers, however, there were no accompanying NLP results. More recently, through advances in machine learning infrastructure, GShard [(Lepikhin et al.,](#page-37-2) [2020)](#page-37-2), which extended the XLA compiler, used the MoE Transformer to dramatically improve machine translation across 100 languages. Finally [Fan et al.](#page-35-10) [(2021)](#page-35-10) chooses a different deterministic MoE strategy to split the model parameters into non-overlapping groups of languages.
Sparsity along the sequence length dimension $(L)$ in the Transformer attention patterns has been a successful technique to reduce the attention complexity from $O(L^2)$ (Child et al., 2019; Correia et al., 2019; Sukhbaatar et al., 2019; Kitaev et al., 2020; Zaheer et al., 2020; Beltagy et al., 2020). This has enabled learning longer sequences than previously possible. This version of the Switch Transformer does not employ attention sparsity, but these techniques are complimentary, and, as future work, these could be combined to potentially improve learning on tasks requiring long contexts.
### <span id="page-24-0"></span>7. Discussion
We pose and discuss questions about the Switch Transformer, and sparse expert models generally, where sparsity refers to weights, not on attention patterns.
Isn't Switch Transformer better due to sheer parameter count? Yes, and by design! Parameters, independent of the total FLOPs used, are a useful axis to scale neural language models. Large models have been exhaustively shown to perform better [(Kaplan](#page-36-0) [et al.,](#page-36-0) [2020)](#page-36-0). But in this case, our model is more sample efficient and faster while using the same computational resources.
I don't have access to a supercomputer—is this still useful for me? Though this work has focused on extremely large models, we also find that models with as few as two experts improves performance while easily fitting within memory constraints of commonly available GPUs or TPUs (details in Appendix [D)](#page-28-2). We therefore believe our techniques are useful in small-scale settings.
Do sparse models outperform dense models on the speed-accuracy Pareto curve? Yes. Across a wide variety of different models sizes, sparse models outperform dense models per step and on wall clock time. Our controlled experiments show for a fixed amount of computation and time, sparse models outperform dense models.
I can't deploy a trillion parameter model—can we shrink these models? We cannot fully preserve the model quality, but compression rates of 10 to 100x are achievable by distilling our sparse models into dense models while achieving $\approx$30% of the quality gain of the expert model.
Why use Switch Transformer instead of a model-parallel dense model? On a time basis, Switch Transformers can be far more efficient than dense-models with sharded parameters (Figure [6)](#page-13-2). Also, we point out that this decision is *not* mutually exclusive—we can, and do, use model-parallelism in Switch Transformers, increasing the FLOPs per token, but incurring the slowdown of conventional model-parallelism.
Why aren't sparse models widely used already? The motivation to try sparse models has been stymied by the massive success of scaling dense models (the success of which is partially driven by co-adaptation with deep learning hardware as argued in [Hooker](#page-36-12) [(2020)](#page-36-12)). Further, sparse models have been subject to multiple issues including (1) model complexity, (2) training difficulties, and (3) communication costs. Switch Transformer makes strides to alleviate these issues.
### <span id="page-25-0"></span>8. Future Work
This paper lays out a simplified architecture, improved training procedures, and a study of how sparse models scale. However, there remain many open future directions which we briefly describe here:
- 1. A significant challenge is further improving training stability for the largest models. While our stability techniques were effective for our Switch-Base, Switch-Large and Switch-C models (no observed instability), they were not sufficient for Switch-XXL. We have taken early steps towards stabilizing these models, which we think may be generally useful for large models, including using regularizers for improving stability and adapted forms of gradient clipping, but this remains unsolved.
- 2. Generally we find that improved pre-training quality leads to better downstream results (Appendix [E)](#page-31-0), though we sometimes encounter striking anomalies. For instance, despite similar perplexities modeling the C4 data set, the 1.6T parameter Switch-C achieves only an 87.7 exact match score in SQuAD, which compares unfavorably to 89.6 for the smaller Switch-XXL model. One notable difference is that the Switch-XXL model applies ≈10x the FLOPS per token than the Switch-C model, even though it has ≈4x less unique parameters (395B vs 1.6T). This suggests a poorly understood dependence between fine-tuning quality, FLOPS per token and number of parameters.
- 3. Perform a comprehensive study of scaling relationships to guide the design of architectures blending data, model and expert-parallelism. Ideally, given the specs of a hardware configuration (computation, memory, communication) one could more rapidly design an optimal model. And, vice versa, this may also help in the design of future hardware.
- 4. Our work falls within the family of adaptive computation algorithms. Our approach always used identical, homogeneous experts, but future designs (facilitated by more flexible infrastructure) could support heterogeneous experts. This would enable more flexible adaptation by routing to larger experts when more computation is desired perhaps for harder examples.
- 5. Investigating expert layers outside the FFN layer of the Transformer. We find preliminary evidence that this similarly can improve model quality. In Appendix [A,](#page-26-1) we report quality improvement adding these inside Self-Attention layers, where our
layer replaces the weight matrices which produce Q, K, V. However, due to training instabilities with the bfloat16 format, we instead leave this as an area for future work.
- 6. Examining Switch Transformer in new and across different modalities. We have thus far only considered language, but we believe that model sparsity can similarly provide advantages in new modalities, as well as multi-modal networks.
This list could easily be extended, but we hope this gives a flavor for the types of challenges that we are thinking about and what we suspect are promising future directions.
### <span id="page-26-0"></span>9. Conclusion
Switch Transformers are scalable and effective natural language learners. We simplify Mixture of Experts to produce an architecture that is easy to understand, stable to train and vastly more sample efficient than equivalently-sized dense models. We find that these models excel across a diverse set of natural language tasks and in different training regimes, including pre-training, fine-tuning and multi-task training. These advances make it possible to train models with hundreds of billion to trillion parameters and which achieve substantial speedups relative to dense T5 baselines. We hope our work motivates sparse models as an effective architecture and that this encourages researchers and practitioners to consider these flexible models in natural language tasks, and beyond.
### Acknowledgments
The authors would like to thank Margaret Li who provided months of key insights into algorithmic improvements and suggestions for empirical studies. Hugo Larochelle for sage advising and clarifying comments on the draft, Irwan Bello for detailed comments and careful revisions, Colin Raffel and Adam Roberts for timely advice on neural language models and the T5 code-base, Yoshua Bengio for advising and encouragement on research in adaptive computation, Jascha Sohl-dickstein for interesting new directions for stabilizing new large scale models and paper revisions, and the Google Brain Team for useful discussions on the paper. Blake Hechtman who provided invaluable help in profiling and improving the training performance of our models.
### <span id="page-26-1"></span>A. Switch for Attention
[Shazeer et al.](#page-38-3) [(2018)](#page-38-3); [Lepikhin et al.](#page-37-2) [(2020)](#page-37-2) designed MoE Transformers [(Shazeer et al.,](#page-38-2) [2017)](#page-38-2) by adding MoE layers into the dense feedfoward network (FFN) computations of the Transformer. Similarly, our work also replaced the FFN layer in the Transformer, but we briefly explore here an alternate design. We add Switch layers into the Transformer Self-Attention layers. To do so, we replace the trainable weight matrices that produce the queries, keys and values with Switch layers as seen in Figure [10.](#page-27-0)
Table [10](#page-27-1) records the quality after a fixed number of steps as well as training time for several variants. Though we find improvements, we also found these layers to be more unstable when using bfloat16 precision and thus we did not include them in the final variant.
![](_page_27_Figure_1.jpeg)
- <span id="page-27-0"></span>Figure 10: Switch layers in attention. We diagram how to incorporate the Switch layer into the Self-Attention transformer block. For each token (here we show two tokens, $x_1$ = “More” and $x_2$ = “Parameters”), one set of weights produces the query and the other set of unique weights produces the shared keys and values. We experimented with each expert being a linear operation, as well as a FFN, as was the case throughout this work. While we found quality improvements using this, we found this to be more unstable when used with low precision number formats, and thus leave it for future work.
| | | | | | However, when these layers do train stably, we believe the preliminary positive results | |
|----------------------------------------|--|--|--|--|-----------------------------------------------------------------------------------------|--|
| suggests a future promising direction. | | | | | | |
| Model | Precision | Quality<br>@100k Steps (↑) | Quality<br>@16H (↑) | Speed<br>(ex/sec) (↑) |
|------------------------|-----------|----------------------------|---------------------|-----------------------|
| Experts FF | float32 | -1.548 | -1.614 | 1480 |
| Expert Attention | float32 | -1.524 | <b>-1.606</b> | 1330 |
| Expert Attention | bfloat16 | [diverges] | [diverges] | |
| Experts FF + Attention | float32 | <b>-1.513</b> | -1.607 | 1240 |
| Expert FF + Attention | bfloat16 | [diverges] | [diverges] | |
- <span id="page-27-1"></span>Table 10: Switch attention layer results. All models have 32 experts and train with 524k tokens per batch. Experts FF is when experts replace the FFN in the Transformer, which is our standard setup throughout the paper. Experts FF + Attention is when experts are used to replace both the FFN and the Self-Attention layers. When training with bfloat16 precision the models that have experts attention diverge.
### <span id="page-28-0"></span>B. Preventing Token Dropping with No-Token-Left-Behind
Due to software constraints on TPU accelerators, the shapes of our Tensors must be statically sized. As a result, each expert has a finite and fixed capacity to process token representations. This, however, presents an issue for our model which dynamically routes tokens at run-time that may result in an uneven distribution over experts. If the number of tokens sent to an expert is less than the expert capacity, then the computation may simply be padded an inefficient use of the hardware, but mathematically correct. However, when the number of tokens sent to an expert is larger than its capacity (expert overflow), a protocol is needed to handle this. [Lepikhin et al.](#page-37-2) [(2020)](#page-37-2) adapts a Mixture-of-Expert model and addresses expert overflow by passing its representation to the next layer without processing through a residual connection which we also follow.
We suspected that having no computation applied to tokens could be very wasteful, especially since if there is overflow on one expert, that means another expert will have extra capacity. With this intuition we create *No-Token-Left-Behind*, which iteratively reroutes any tokens that are at first routed to an expert that is overflowing. Figure [11](#page-29-0) shows a graphical description of this method, which will allow us to guarantee almost no tokens will be dropped during training and inference. We hypothesised that this could improve performance and further stabilize training, but we found no empirical benefits. We suspect that once the network learns associations between different tokens and experts, if this association is changed (e.g. sending a token to its second highest expert) then performance could be degraded.
### <span id="page-28-1"></span>C. Encouraging Exploration Across Experts
At each expert-layer, the router determines to which expert to send the token. This is a discrete decision over the available experts, conditioned on information about the token's representation. Based on the incoming token representation, the router determines the best expert, however, it receives no counterfactual information about how well it would have done selecting an alternate expert. As in reinforcement learning, a classic explorationexploitation dilemma arises [(Sutton and Barto,](#page-38-12) [2018)](#page-38-12). These issues have been similarly noted and addressed differently by [Rosenbaum et al.](#page-38-13) [(2017)](#page-38-13) which demonstrated success in multi-task learning. This particular setting most closely matches that of a contextual bandit [(Robbins,](#page-37-12) [1952)](#page-37-12). Deterministically selecting the top expert always amounts to an exploitative strategy we consider balancing exploration to seek better expert assignment.
To introduce exploration, we consider several approaches: 1) deterministic or argmax 2) sampling from the softmax distribution 3) input dropout on the incoming representation 4) multiplicative jitter noise on the incoming representation. The resulting impact on model quality is reported in Table [11.](#page-29-1) Throughout this work, we use input jitter to inject noise as we have found it to empirically perform the best.
### <span id="page-28-2"></span>D. Switch Transformers in Lower Compute Regimes
Switch Transformer is also an effective architecture at small scales as well as in regimes with thousands of cores and trillions of parameters. Many of our prior experiments were
![](_page_29_Figure_1.jpeg)
- <span id="page-29-0"></span>Figure 11: Diagram of the No-Token-Left-Behind Routing. Stage 1 is equivalent to Switch routing where tokens are routed to the expert with the highest probability from the router. In Stage 2 we look at all tokens that have overflowed and route them to the expert with which has the second highest probability. Tokens can still be overflowed if their second highest expert has too many tokens, but this allows most of the tokens to be routed. This process can be iterated to guarantee virtually no tokens are dropped at all.
| Model | Quality (Neg. Log Perp.) (↑) |
|----------------|------------------------------|
| Argmax | -1.471 |
| Sample softmax | -1.570 |
| Input dropout | -1.480 |
| Input jitter | <b>-1.468</b> |
- <span id="page-29-1"></span>Table 11: Router Exploration Strategies. Quality of the Switch Transformer, measured by the negative log perplexity, under different randomness-strategies for selecting the expert (lower is better). There is no material speed performance difference between the variants.
at the scale of 10B+ parameter models, but we show in Figure [12](#page-30-0) as few as 2 experts produce compelling gains over a FLOP-matched counterpart. Even if a super computer is not readily available, training Switch Transformers with 2, 4, or 8 experts (as we typically recommend one expert per core) results in solid improvements over T5 dense baselines.
![](_page_30_Figure_1.jpeg)
<span id="page-30-0"></span>Figure 12: Switch Transformer with few experts. Switch Transformer improves over the baseline even with very few experts. Here we show scaling properties at very small scales, where we improve over the T5-Base model using 2, 4, and 8 experts.
### <span id="page-31-0"></span>E. Relation of Upstream to Downstream Model Performance
There is no guarantee that a model's quality on a pre-training objective will translate to downstream task results. Figure [13](#page-31-1) presents the correlation of the upstream model quality, for both dense and Switch models, on the C4 pre-training task with two downstream task measures: average SuperGLUE performance and TriviaQA score. We choose these two tasks as one probes the model's reasoning and the other factual knowledge.
![](_page_31_Figure_3.jpeg)
<span id="page-31-1"></span>Figure 13: Upstream pre-trained quality to downstream model quality. We correlate the upstream performance with downstream quality on both SuperGLUE and TriviaQA (SOTA recorded without SSM), reasoning and knowledge-heavy benchmarks, respectively (validation sets). We find that, as with the baseline, the Switch model scales with improvements in the upstream pre-training task. For SuperGLUE, we find a loosely linear relation between negative log perplexity and the average SuperGLUE score. However, the dense model often performs better for a fixed perplexity, particularly in the large-scale regime. Conversely, on the knowledge-heavy task, TriviaQA, we find that the Switch Transformer may follow an improved scaling relationship for a given upstream perplexity, it does better than a dense counterpart. Further statistics (expensive to collect and left to future work) would be necessary to confirm these observations.
We find a consistent correlation, indicating that for both baseline and Switch models, improved pre-training leads to better downstream results. Additionally, for a fixed upstream perplexity we find that both Switch and dense models perform similarly in the small to medium model size regime. However, in the largest model regime (T5-11B/T5-XXL) our largest Switch models, as mentioned in Section [5.6,](#page-21-2) do not always translate their upstream perplexity well to downstream fine-tuning on the SuperGLUE task. This warrants future investigation and study to fully realize the potential of sparse models. Understanding the fine-tuning dynamics with expert-models is very complicated and is dependent on regularization, load-balancing, and fine-tuning hyper-parameters.
### <span id="page-32-0"></span>F. Pseudo Code for Switch Transformers
Pseudocode for Switch Transformers in Mesh Tensorflow [(Shazeer et al.,](#page-38-3) [2018)](#page-38-3). No model parallelism is being used for the below code (see [5.4](#page-21-0) for more details).
```
import mesh tensorflow as mtf
def load balance loss(router probs, expert mask):
"""Calculate loadbalancing loss to ensure diverse expert routing."""
# router probs is the probability assigned for each expert per token.
# router probs shape: [num cores, tokens per core, num experts]
# expert index contains the expert with the highest router probability in onehot format.
# expert mask shape: [num cores, tokens per core, num experts]
# For each core, get the fraction of tokens routed to each expert.
# density 1 shape: [num cores, num experts]
density 1 = mtf.reduce mean(expert mask, reduced dim=tokens per core)
# For each core, get fraction of probability mass assigned to each expert
# from the router across all tokens.
# density 1 proxy shape: [num cores, num experts]
density 1 proxy = mtf.reduce mean(router probs, reduced dim=tokens per core)
# density l for a single core: vector of length num experts that sums to 1.
# density l proxy for a single core: vector of length num experts that sums to 1.
# Want both vectors to have uniform allocation (1/num experts) across all num expert elements.
# The two vectors will be pushed towards uniform allocation when the dot product is minimized.
loss = mtf.reduce mean(density 1 proxy density 1) (num experts ˆ 2)
return loss
```
- Figure 14: Pseudo code for the load balance loss for Switch Transformers in Mesh Tensorflow.
```
import mesh tensorflow as mtf
```
```
def router(inputs, capacity factor):
"""Produce the combine and dispatch tensors used for sending and
receiving tokens from their highest probability expert. """
# Core layout is split across num cores for all tensors and operations.
# inputs shape: [num cores, tokens per core, d model]
```
#### router weights = mtf.Variable(shape=[d model, num experts])
```
# router logits shape: [num cores, tokens per core, num experts]
router logits = mtf.einsum([inputs, router weights], reduced dim=d model)
```
#### if is training:
# Add noise for exploration across experts. router logits += mtf.random uniform(shape=router logits.shape, minval=1eps, maxval=1+eps)
```
# Convert input to softmax operation from bfloat16 to float32 for stability.
router logits = mtf.to float32(router logits)
```
```
# Probabilities for each token of what expert it should be sent to.
router probs = mtf.softmax(router logits, axis=1)
```
```
# Get the top1 expert for each token. expert gate is the top1 probability
# from the router for each token. expert index is what expert each token
# is going to be routed to.
# expert gate shape: [num cores, tokens per core]
# expert index shape: [num cores, tokens per core]
expert gate, expert index = mtf.top 1(router probs, reduced dim=num experts)
# expert mask shape: [num cores, tokens per core, num experts]
expert mask = mtf.one hot(expert index, dimension=num experts)
# Compute load balancing loss.
aux loss = load balance loss(router probs, expert mask)
# Experts have a fixed capacity, ensure we do not exceed it. Construct
# the batch indices, to each expert, with position in expert
# make sure that not more that expert capacity examples can be routed to
# each expert.
position in expert = mtf.cumsum(expert mask, dimension=tokens per core) expert mask
# Keep only tokens that fit within expert capacity.
expert mask = mtf.less(position in expert, expert capacity)
expert mask flat = mtf.reduce sum(expert mask, reduced dim=experts dim)
# Mask out the experts that have overflowed the expert capacity.
expert gate = expert mask flat
# combine tensor used for combining expert outputs and scaling with router probability.
# combine tensor shape: [num cores, tokens per core, num experts, expert capacity]
combine tensor = (
expert gate expert mask flat
mtf.one hot(expert index, dimension=num experts)
mtf.one hot(position in expert, dimension=expert capacity))
# Cast back outputs to bfloat16 for the rest of the layer.
combine tensor = mtf.to bfloat16(combine tensor)
# Create binary dispatch tensor that is 1 if the token gets routed to the corresponding expert.
# dispatch tensor shape: [num cores, tokens per core, num experts, expert capacity]
dispatch tensor = mtf.cast(combine tensor, tf.bool)
```
```
return dispatch tensor, combine tensor, aux loss
```
<span id="page-33-0"></span>Figure 15: Pseudo code for the router for Switch Transformers in Mesh Tensorflow.
```
import mesh tensorflow as mtf
```
def switch layer(inputs, n, capacity factor, num experts): """Distributed switch transformer feedforward layer."""
```
# num cores (n) = total cores for training the model (scalar).
# d model = model hidden size (scalar).
# num experts = total number of experts.
# capacity factor = extra buffer for each expert.
# inputs shape: [batch, seq len, d model]
batch, seq len, d model = inputs.get shape()
# Each core will route tokens per core tokens to the correct experts.
tokens per core = batch seq len / num cores
# Each expert will have shape [num cores, expert capacity, d model].
# Each core is responsible for sending expert capacity tokens
# to each expert.
expert capacity = tokens per core capacity factor / num experts
# Reshape to setup per core expert dispatching.
# shape: [batch, seq len, d model] > [num cores, tokens per core, d model]
# Core layout: [n, 1, 1] > [n, 1, 1]
inputs = mtf.reshape(inputs, [num cores, tokens per core, d model])
# Core Layout: [n, 1, 1] > [n, 1, 1, 1], [n, 1, 1, 1]
# dispatch tensor (boolean) shape: [num cores, tokens per core, num experts, expert capacity]
# dispatch tensor is used for routing tokens to the correct expert.
# combine tensor (float) shape: [num cores, tokens per core, num experts, expert capacity]
# combine tensor used for combining expert outputs and scaling with router
# probability.
dispatch tensor, combine tensor, aux loss = router(inputs, expert capacity)
# Matmul with large boolean tensor to assign tokens to the correct expert.
# Core Layout: [n, 1, 1], > [1, n, 1, 1]
# expert inputs shape: [num experts, num cores, expert capacity, d model]
expert inputs = mtf.einsum([inputs, dispatch tensor], reduce dims=[tokens per core])
# AlltoAll communication. Cores split across num cores and now we want to split
# across num experts. This sends tokens, routed locally, to the correct expert now
# split across different cores.
# Core layout: [1, n, 1, 1] > [n, 1, 1, 1]
expert inputs = mtf.reshape(expert inputs, [num experts, num cores, expert capacity, d model])
# Standard feed forward computation, where each expert will have its own
# unique set of parameters.
# Total unique parameters created: num experts (d model d ff 2).
# expert outputs shape: [num experts, num cores, expert capacity, d model]
expert outputs = feed forward(expert inputs)
# AlltoAll communication. Cores are currently split across the experts
# dimension, which needs to be switched back to being split across num cores.
# Core Layout: [n, 1, 1, 1] > [1, n, 1, 1]
expert outputs = mtf.reshape(expert outputs, [num experts, num cores, expert capacity, d model])
# Convert back to input shape and multiply outputs of experts by the routing probability.
# expert outputs shape: [num experts, num cores, tokens per core, d model]
# expert outputs combined shape: [num cores, tokens per core, d model]
# Core Layout: [1, n, 1, 1] > [n, 1, 1]
expert outputs combined = mtf.einsum([expert outputs, combine tensor], reduce dims=[tokens per core])
# Remove tokens per core shapes used for local routing dispatching to match input shape.
# Core Layout: [n, 1, 1] > [n, 1, 1]
outputs = mtf.reshape(expert outputs combined, [batch, seq len, d model])
return outputs, aux loss
```
Figure 16: Pseudo code of the Switch Transformer layer in Mesh Tensorflow.
### References
- <span id="page-35-3"></span>Mart´ın Abadi, Paul Barham, Jianmin Chen, Zhifeng Chen, Andy Davis, Jeffrey Dean, Matthieu Devin, Sanjay Ghemawat, Geoffrey Irving, Michael Isard, et al. Tensorflow: A system for large-scale machine learning. In 12th {USENIX} symposium on operating systems design and implementation ({OSDI} 16), pages 265283, 2016.
- <span id="page-35-13"></span>Iz Beltagy, Matthew E Peters, and Arman Cohan. Longformer: The long-document transformer. arXiv preprint arXiv:2004.05150, 2020.
- <span id="page-35-7"></span>Jonathan Berant, Andrew Chou, Roy Frostig, and Percy Liang. Semantic parsing on freebase from question-answer pairs. In Proceedings of the 2013 conference on empirical methods in natural language processing, pages 15331544, 2013.
- <span id="page-35-0"></span>Tom B Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. arXiv preprint arXiv:2005.14165, 2020.
- <span id="page-35-11"></span>Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever. Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509, 2019.
- <span id="page-35-8"></span>Kyunghyun Cho and Yoshua Bengio. Exponentially increasing the capacity-to-computation ratio for conditional computation in deep learning. arXiv preprint arXiv:1406.7362, 2014.
- <span id="page-35-6"></span>Peter Clark, Isaac Cowhey, Oren Etzioni, Tushar Khot, Ashish Sabharwal, Carissa Schoenick, and Oyvind Tafjord. Think you have solved question answering? try arc, the ai2 reasoning challenge. arXiv preprint arXiv:1803.05457, 2018.
- <span id="page-35-12"></span>Gon¸calo M Correia, Vlad Niculae, and Andr´e FT Martins. Adaptively sparse transformers. arXiv preprint arXiv:1909.00015, 2019.
- <span id="page-35-5"></span>Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pretraining of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
- <span id="page-35-9"></span>David Eigen, Marc'Aurelio Ranzato, and Ilya Sutskever. Learning factored representations in a deep mixture of experts. arXiv preprint arXiv:1312.4314, 2013.
- <span id="page-35-10"></span>Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, et al. Beyond english-centric multilingual machine translation. Journal of Machine Learning Research, 22(107):148, 2021.
- <span id="page-35-4"></span>William Fedus, Ian Goodfellow, and Andrew M Dai. Maskgan: Better text generation via filling in the . arXiv preprint arXiv:1801.07736, 2018.
- <span id="page-35-2"></span>Trevor Gale, Matei Zaharia, Cliff Young, and Erich Elsen. Sparse gpu kernels for deep learning. arXiv preprint arXiv:2006.10901, 2020.
- <span id="page-35-1"></span>Scott Gray, Alec Radford, and Diederik P Kingma. Gpu kernels for block-sparse weights. https://openai.com/blog/block-sparse-gpu-kernels/, 2017.
- <span id="page-36-7"></span>Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat, and Ming-Wei Chang. Realm: Retrieval-augmented language model pre-training. arXiv preprint arXiv:2002.08909, 2020.
- <span id="page-36-8"></span>Aaron Harlap, Deepak Narayanan, Amar Phanishayee, Vivek Seshadri, Nikhil Devanur, Greg Ganger, and Phil Gibbons. Pipedream: Fast and efficient pipeline parallel dnn training. arXiv preprint arXiv:1806.03377, 2018.
- <span id="page-36-4"></span>Karl Moritz Hermann, Tomas Kocisky, Edward Grefenstette, Lasse Espeholt, Will Kay, Mustafa Suleyman, and Phil Blunsom. Teaching machines to read and comprehend. In C. Cortes, N. Lawrence, D. Lee, M. Sugiyama, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 28, pages 16931701. Curran Associates, Inc., 2015. URL [https://proceedings.neurips.cc/paper/2015/file/](https://proceedings.neurips.cc/paper/2015/file/afdec7005cc9f14302cd0474fd0f3c96-Paper.pdf) [afdec7005cc9f14302cd0474fd0f3c96-Paper.pdf](https://proceedings.neurips.cc/paper/2015/file/afdec7005cc9f14302cd0474fd0f3c96-Paper.pdf).
- <span id="page-36-3"></span>Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
- <span id="page-36-10"></span>Sepp Hochreiter and J¨urgen Schmidhuber. Long short-term memory. Neural computation, 9(8):17351780, 1997.
- <span id="page-36-12"></span>Sara Hooker. The hardware lottery. arXiv preprint arXiv:2009.06489, 2020.
- <span id="page-36-9"></span>Yanping Huang, Youlong Cheng, Ankur Bapna, Orhan Firat, Dehao Chen, Mia Chen, HyoukJoong Lee, Jiquan Ngiam, Quoc V Le, Yonghui Wu, et al. Gpipe: Efficient training of giant neural networks using pipeline parallelism. In Advances in neural information processing systems, pages 103112, 2019.
- <span id="page-36-1"></span>Robert A Jacobs, Michael I Jordan, Steven J Nowlan, and Geoffrey E Hinton. Adaptive mixtures of local experts. Neural computation, 3(1):7987, 1991.
- <span id="page-36-2"></span>Michael I Jordan and Robert A Jacobs. Hierarchical mixtures of experts and the em algorithm. Neural computation, 6(2):181214, 1994.
- <span id="page-36-6"></span>Mandar Joshi, Eunsol Choi, Daniel S Weld, and Luke Zettlemoyer. Triviaqa: A large scale distantly supervised challenge dataset for reading comprehension. arXiv preprint arXiv:1705.03551, 2017.
- <span id="page-36-0"></span>Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. Scaling laws for neural language models. arXiv preprint arXiv:2001.08361, 2020.
- <span id="page-36-11"></span>Nikita Kitaev, Lukasz Kaiser, and Anselm Levskaya. Reformer: The efficient transformer. arXiv preprint arXiv:2001.04451, 2020.
- <span id="page-36-5"></span>Tom Kwiatkowski, Jennimaria Palomaki, Olivia Redfield, Michael Collins, Ankur Parikh, Chris Alberti, Danielle Epstein, Illia Polosukhin, Jacob Devlin, Kenton Lee, et al. Natural questions: a benchmark for question answering research. Transactions of the Association for Computational Linguistics, 7:453466, 2019.
- <span id="page-37-10"></span>Guillaume Lample, Alexandre Sablayrolles, Marc'Aurelio Ranzato, Ludovic Denoyer, and Herv´e J´egou. Large memory layers with product keys. In Advances in Neural Information Processing Systems, pages 85488559, 2019.
- <span id="page-37-5"></span>Katherine Lee, Daphne Ippolito, Andrew Nystrom, Chiyuan Zhang, Douglas Eck, Chris Callison-Burch, and Nicholas Carlini. Deduplicating training data makes language models better. arXiv preprint arXiv:2107.06499, 2021.
- <span id="page-37-2"></span>Dmitry Lepikhin, HyoukJoong Lee, Yuanzhong Xu, Dehao Chen, Orhan Firat, Yanping Huang, Maxim Krikun, Noam Shazeer, and Zhifeng Chen. Gshard: Scaling giant models with conditional computation and automatic sharding. arXiv preprint arXiv:2006.16668, 2020.
- <span id="page-37-4"></span>Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich Elsen, David Garcia, Boris Ginsburg, Michael Houston, Oleksii Kuchaiev, Ganesh Venkatesh, et al. Mixed precision training. arXiv preprint arXiv:1710.03740, 2017.
- <span id="page-37-6"></span>Shashi Narayan, Shay B Cohen, and Mirella Lapata. Don't give me the details, just the summary! topic-aware convolutional neural networks for extreme summarization. arXiv preprint arXiv:1808.08745, 2018.
- <span id="page-37-8"></span>Yixin Nie, Adina Williams, Emily Dinan, Mohit Bansal, Jason Weston, and Douwe Kiela. Adversarial nli: A new benchmark for natural language understanding. arXiv preprint arXiv:1910.14599, 2019.
- <span id="page-37-11"></span>Joan Puigcerver, Carlos Riquelme, Basil Mustafa, Cedric Renggli, Andr´e Susano Pinto, Sylvain Gelly, Daniel Keysers, and Neil Houlsby. Scalable transfer learning with expert models. arXiv preprint arXiv:2009.13239, 2020.
- <span id="page-37-1"></span>Alec Radford, Karthik Narasimhan, Tim Salimans, and Ilya Sutskever. Improving language understanding by generative pre-training, 2018.
- <span id="page-37-0"></span>Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, and Peter J Liu. Exploring the limits of transfer learning with a unified text-to-text transformer. arXiv preprint arXiv:1910.10683, 2019.
- <span id="page-37-9"></span>Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, and Yuxiong He. Zero: Memory optimization towards training a trillion parameter models. arXiv preprint arXiv:1910.02054, 2019.
- <span id="page-37-7"></span>Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. Squad: 100,000+ questions for machine comprehension of text. arXiv preprint arXiv:1606.05250, 2016.
- <span id="page-37-3"></span>Prajit Ramachandran and Quoc V Le. Diversity and depth in per-example routing models. In International Conference on Learning Representations, 2018.
- <span id="page-37-12"></span>Herbert Robbins. Some aspects of the sequential design of experiments. Bulletin of the American Mathematical Society, 58(5):527535, 1952.
- <span id="page-38-6"></span>Adam Roberts, Colin Raffel, and Noam Shazeer. How much knowledge can you pack into the parameters of a language model? arXiv preprint arXiv:2002.08910, 2020.
- <span id="page-38-13"></span>Clemens Rosenbaum, Tim Klinger, and Matthew Riemer. Routing networks: Adaptive selection of non-linear functions for multi-task learning. arXiv preprint arXiv:1711.01239, 2017.
- <span id="page-38-7"></span>Keisuke Sakaguchi, Ronan Le Bras, Chandra Bhagavatula, and Yejin Choi. Winogrande: An adversarial winograd schema challenge at scale. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pages 87328740, 2020.
- <span id="page-38-8"></span>Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf. Distilbert, a distilled version of bert: smaller, faster, cheaper and lighter, 2019.
- <span id="page-38-9"></span>Noam Shazeer. Glu variants improve transformer, 2020.
- <span id="page-38-2"></span>Noam Shazeer, Azalia Mirhoseini, Krzysztof Maziarz, Andy Davis, Quoc Le, Geoffrey Hinton, and Jeff Dean. Outrageously large neural networks: The sparsely-gated mixture-ofexperts layer. arXiv preprint arXiv:1701.06538, 2017.
- <span id="page-38-3"></span>Noam Shazeer, Youlong Cheng, Niki Parmar, Dustin Tran, Ashish Vaswani, Penporn Koanantakool, Peter Hawkins, HyoukJoong Lee, Mingsheng Hong, Cliff Young, et al. Mesh-tensorflow: Deep learning for supercomputers. In Advances in Neural Information Processing Systems, pages 1041410423, 2018.
- <span id="page-38-10"></span>Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper, and Bryan Catanzaro. Megatron-lm: Training multi-billion parameter language models using gpu model parallelism. arXiv preprint arXiv:1909.08053, 2019.
- <span id="page-38-5"></span>Nitish Srivastava, Geoffrey E. Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: a simple way to prevent neural networks from overfitting. Journal of Machine Learning Research, 15(1):19291958, 2014. URL [http://www.cs.](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf) [toronto.edu/~rsalakhu/papers/srivastava14a.pdf](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf).
- <span id="page-38-1"></span>Emma Strubell, Ananya Ganesh, and Andrew McCallum. Energy and policy considerations for deep learning in nlp. arXiv preprint arXiv:1906.02243, 2019.
- <span id="page-38-11"></span>Sainbayar Sukhbaatar, Edouard Grave, Piotr Bojanowski, and Armand Joulin. Adaptive attention span in transformers. arXiv preprint arXiv:1905.07799, 2019.
- <span id="page-38-0"></span>Rich Sutton. The Bitter Lesson. http://www.incompleteideas.net/IncIdeas/BitterLesson.html, 2019.
- <span id="page-38-12"></span>Richard S Sutton and Andrew G Barto. Reinforcement learning: An introduction. Stanford University, 2018.
- <span id="page-38-4"></span>Wilson L Taylor. "cloze procedure": A new tool for measuring readability. Journalism quarterly, 30(4):415433, 1953.
- <span id="page-39-1"></span><span id="page-39-0"></span>Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in neural information processing systems, pages 59986008, 2017.
- <span id="page-39-4"></span>Alex Wang, Amanpreet Singh, Julian Michael, Felix Hill, Omer Levy, and Samuel R Bowman. Glue: A multi-task benchmark and analysis platform for natural language understanding. arXiv preprint arXiv:1804.07461, 2018.
- <span id="page-39-5"></span>Alex Wang, Yada Pruksachatkun, Nikita Nangia, Amanpreet Singh, Julian Michael, Felix Hill, Omer Levy, and Samuel Bowman. Superglue: A stickier benchmark for generalpurpose language understanding systems. In Advances in Neural Information Processing Systems, pages 32663280, 2019.
- <span id="page-39-3"></span>Shibo Wang and Pankaj Kanwar. Bfloat16: The secret to high performance on cloud tpus. Google Cloud Blog, 2019.
- <span id="page-39-2"></span>Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, and Colin Raffel. mt5: A massively multilingual pre-trained text-to-text transformer. arXiv preprint arXiv:2010.11934, 2020.
- <span id="page-39-6"></span>Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, and Quoc V. Le. Xlnet: Generalized autoregressive pretraining for language understanding, 2020.
- <span id="page-39-7"></span>Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, et al. Big bird: Transformers for longer sequences. arXiv preprint arXiv:2007.14062, 2020.

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save