refactor: use llama.cpp tokenizer (#683)

* refactor: switch to llama.cpp tokenizer to simplify implementation

* refactor: remove tokenizer dependency from tabby

* refactor: rename decoding to stop condition

* refactor: remove tokenizer dependency

* refactor: remove submodule

* chore: update formatting

* move tokenization to C++
Meng Zhang 2023-10-31 15:16:09 -07:00 committed by GitHub
parent f15926f233
commit 296342efd8
24 changed files with 143 additions and 1198 deletions
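
At a glance, the change moves tokenization across the FFI boundary: the Rust side no longer loads a tokenizer.json through the tokenizers crate, and the llama.cpp engine now accepts raw prompt text and emits decoded text pieces. A condensed sketch of the bridge change, paraphrased from the llama-cpp-bindings diffs below (not the literal code):

// Before: Rust tokenized the prompt and exchanged raw token ids.
fn add_request(self: Pin<&mut TextInferenceEngine>, request_id: u32, input_token_ids: &[u32]);
fn step(self: Pin<&mut TextInferenceEngine>) -> Result<Vec<u32>>; // flat [request_id, token_id, ...] pairs
fn eos_token_id(&self) -> u32;

// After: llama.cpp tokenizes and detokenizes; Rust sees only text.
struct StepOutput {
    request_id: u32,
    text: String, // empty text signals end of stream (eos)
}
fn add_request(self: Pin<&mut TextInferenceEngine>, request_id: u32, prompt: &str, max_input_length: usize);
fn step(self: Pin<&mut TextInferenceEngine>) -> Result<Vec<StepOutput>>;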

.gitmodules (vendored, 3 lines changed)

@ -1,6 +1,3 @@
[submodule "crates/ctranslate2-bindings/CTranslate2"]
path = crates/ctranslate2-bindings/CTranslate2
url = https://github.com/OpenNMT/CTranslate2.git
[submodule "crates/llama-cpp-bindings/llama.cpp"]
path = crates/llama-cpp-bindings/llama.cpp
url = https://github.com/TabbyML/llama.cpp


@ -10,6 +10,7 @@
* Switch cpu backend to llama.cpp: https://github.com/TabbyML/tabby/pull/638
* add `server.completion_timeout` to control the code completion interface timeout: https://github.com/TabbyML/tabby/pull/637
* Switch cuda backend to llama.cpp: https://github.com/TabbyML/tabby/pull/656
* Switch tokenizer to llama.cpp, so tabby no longer needs to download an additional tokenizer file: https://github.com/TabbyML/tabby/pull/683
# v0.4.0

Cargo.lock (generated, 457 lines changed)

@ -17,17 +17,6 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "aes"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "433cfd6710c9986c576a25ca913c39d66a6474107b406f34f91d4a8923395241"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
]
[[package]]
name = "ahash"
version = "0.8.3"
@ -39,15 +28,6 @@ dependencies = [
"version_check",
]
[[package]]
name = "aho-corasick"
version = "0.7.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac"
dependencies = [
"memchr",
]
[[package]]
name = "aho-corasick"
version = "1.1.2"
@ -309,12 +289,6 @@ version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d"
[[package]]
name = "base64ct"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "bitflags"
version = "1.3.2"
@ -373,27 +347,6 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be"
[[package]]
name = "bzip2"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8"
dependencies = [
"bzip2-sys",
"libc",
]
[[package]]
name = "bzip2-sys"
version = "0.1.11+1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc"
dependencies = [
"cc",
"libc",
"pkg-config",
]
[[package]]
name = "cached"
version = "0.46.0"
@ -412,28 +365,6 @@ dependencies = [
"tokio",
]
[[package]]
name = "cached-path"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "097968e38f1319207f057d0f4d76452e4f4f847a5de61c5215379f297fa034f3"
dependencies = [
"flate2",
"fs2",
"glob",
"indicatif 0.16.2",
"log",
"rand",
"reqwest",
"serde",
"serde_json",
"sha2",
"tar",
"tempfile",
"thiserror",
"zip",
]
[[package]]
name = "cached_proc_macro"
version = "0.18.0"
@ -494,16 +425,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "cipher"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
dependencies = [
"crypto-common",
"inout",
]
[[package]]
name = "clap"
version = "4.3.0"
@ -584,12 +505,6 @@ dependencies = [
"windows-sys 0.45.0",
]
[[package]]
name = "constant_time_eq"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc"
[[package]]
name = "core-foundation"
version = "0.9.3"
@ -694,24 +609,6 @@ dependencies = [
"typenum",
]
[[package]]
name = "ctranslate2-bindings"
version = "0.5.0-dev"
dependencies = [
"async-stream",
"async-trait",
"cmake",
"cxx",
"cxx-build",
"derive_builder",
"futures",
"rust-cxx-cmake-bridge",
"tabby-inference",
"tokenizers",
"tokio",
"tokio-util",
]
[[package]]
name = "cxx"
version = "1.0.95"
@ -887,7 +784,6 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"crypto-common",
"subtle",
]
[[package]]
@ -967,15 +863,6 @@ dependencies = [
"backtrace",
]
[[package]]
name = "esaxx-rs"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f748b253ceca9fed5f42f8b5ceb3851e93102199bc25b64b65369f76e5c0a35"
dependencies = [
"cc",
]
[[package]]
name = "fastdivide"
version = "0.4.0"
@ -1010,18 +897,6 @@ dependencies = [
"regex",
]
[[package]]
name = "filetime"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cbc844cecaee9d4443931972e1289c8ff485cb4cc2767cb03ca139ed6885153"
dependencies = [
"cfg-if",
"libc",
"redox_syscall 0.2.16",
"windows-sys 0.48.0",
]
[[package]]
name = "fixedbitset"
version = "0.4.2"
@ -1068,16 +943,6 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "fs2"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213"
dependencies = [
"libc",
"winapi",
]
[[package]]
name = "fs4"
version = "0.6.6"
@ -1217,19 +1082,13 @@ version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4"
[[package]]
name = "glob"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]]
name = "globset"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "759c97c1e17c55525b57192c06a267cda0ac5210b222d6b82189a2338fa1c13d"
dependencies = [
"aho-corasick 1.1.2",
"aho-corasick",
"bstr",
"fnv",
"log",
@ -1292,15 +1151,6 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
[[package]]
name = "hmac"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
dependencies = [
"digest",
]
[[package]]
name = "htmlescape"
version = "0.3.1"
@ -1476,30 +1326,6 @@ dependencies = [
"serde",
]
[[package]]
name = "indicatif"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7baab56125e25686df467fe470785512329883aab42696d661247aca2a2896e4"
dependencies = [
"console",
"lazy_static",
"number_prefix 0.3.0",
"regex",
]
[[package]]
name = "indicatif"
version = "0.16.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d207dc617c7a380ab07ff572a6e52fa202a2a8f355860ac9c38e23f8196be1b"
dependencies = [
"console",
"lazy_static",
"number_prefix 0.4.0",
"regex",
]
[[package]]
name = "indicatif"
version = "0.17.3"
@ -1507,20 +1333,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cef509aa9bc73864d6756f0d34d35504af3cf0844373afe9b8669a5b8005a729"
dependencies = [
"console",
"number_prefix 0.4.0",
"number_prefix",
"portable-atomic 0.3.20",
"unicode-width",
]
[[package]]
name = "inout"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5"
dependencies = [
"generic-array",
]
[[package]]
name = "instant"
version = "0.1.12"
@ -1562,24 +1379,6 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "itertools"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f56a2d0bc861f9165be4eb3442afd3c236d8a98afd426f65d92324ae1091a484"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.10.5"
@ -1694,7 +1493,6 @@ dependencies = [
"derive_builder",
"futures",
"tabby-inference",
"tokenizers",
"tokio",
"tokio-util",
]
@ -1747,22 +1545,6 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ea9b256699eda7b0387ffbc776dd625e28bde3918446381781245b7a50349d8"
[[package]]
name = "macro_rules_attribute"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf0c9b980bf4f3a37fd7b1c066941dd1b1d0152ce6ee6e8fe8c49b9f6810d862"
dependencies = [
"macro_rules_attribute-proc_macro",
"paste",
]
[[package]]
name = "macro_rules_attribute-proc_macro"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
[[package]]
name = "matchers"
version = "0.0.1"
@ -1890,27 +1672,6 @@ dependencies = [
"windows-sys 0.45.0",
]
[[package]]
name = "monostate"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0230b703f1ac35df1e24f6d0d2255472bcccaf657ecdfa4f1fcbcad1ad5bb98a"
dependencies = [
"monostate-impl",
"serde",
]
[[package]]
name = "monostate-impl"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.28",
]
[[package]]
name = "multimap"
version = "0.8.3"
@ -2007,12 +1768,6 @@ dependencies = [
"libc",
]
[[package]]
name = "number_prefix"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a"
[[package]]
name = "number_prefix"
version = "0.4.0"
@ -2066,28 +1821,6 @@ dependencies = [
"loom",
]
[[package]]
name = "onig"
version = "6.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c4b31c8722ad9171c6d77d3557db078cab2bd50afcc9d09c8b315c59df8ca4f"
dependencies = [
"bitflags 1.3.2",
"libc",
"once_cell",
"onig_sys",
]
[[package]]
name = "onig_sys"
version = "69.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b829e3d7e9cc74c7e315ee8edb185bf4190da5acde74afd7fc59c35b1f086e7"
dependencies = [
"cc",
"pkg-config",
]
[[package]]
name = "openssl"
version = "0.10.52"
@ -2250,35 +1983,12 @@ dependencies = [
"windows-targets 0.48.0",
]
[[package]]
name = "password-hash"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700"
dependencies = [
"base64ct",
"rand_core",
"subtle",
]
[[package]]
name = "paste"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79"
[[package]]
name = "pbkdf2"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917"
dependencies = [
"digest",
"hmac",
"password-hash",
"sha2",
]
[[package]]
name = "percent-encoding"
version = "2.2.0"
@ -2500,17 +2210,6 @@ dependencies = [
"rayon-core",
]
[[package]]
name = "rayon-cond"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd1259362c9065e5ea39a789ef40b1e3fd934c94beb7b5ab3ac6629d3b5e7cb7"
dependencies = [
"either",
"itertools 0.8.2",
"rayon",
]
[[package]]
name = "rayon-core"
version = "1.11.0"
@ -2558,7 +2257,7 @@ version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d119d7c7ca818f8a53c300863d4f87566aac09943aef5b355bb83969dae75d87"
dependencies = [
"aho-corasick 1.1.2",
"aho-corasick",
"memchr",
"regex-automata 0.4.1",
"regex-syntax 0.8.1",
@ -2579,7 +2278,7 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "465c6fc0621e4abc4187a2bda0937bfd4f722c2730b29562e19689ea796c9a4b"
dependencies = [
"aho-corasick 1.1.2",
"aho-corasick",
"memchr",
"regex-syntax 0.8.1",
]
@ -2590,12 +2289,6 @@ version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da"
[[package]]
name = "regex-syntax"
version = "0.8.1"
@ -2946,17 +2639,6 @@ dependencies = [
"trackable",
]
[[package]]
name = "sha1"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "sha2"
version = "0.10.6"
@ -3029,18 +2711,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "spm_precompiled"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326"
dependencies = [
"base64 0.13.1",
"nom 7.1.3",
"serde",
"unicode-segmentation",
]
[[package]]
name = "stable_deref_trait"
version = "1.2.0"
@ -3093,12 +2763,6 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "subtle"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc"
[[package]]
name = "syn"
version = "1.0.109"
@ -3215,7 +2879,7 @@ dependencies = [
"async-trait",
"cached",
"futures-util",
"indicatif 0.17.3",
"indicatif",
"reqwest",
"serde",
"serde_json",
@ -3237,7 +2901,6 @@ dependencies = [
"futures",
"regex",
"tabby-common",
"tokenizers",
]
[[package]]
@ -3273,7 +2936,7 @@ version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1d4675fed6fe2218ce11445374e181e864a8ffd0f28e7e0591ccfc38cd000ae"
dependencies = [
"aho-corasick 1.1.2",
"aho-corasick",
"arc-swap",
"async-trait",
"base64 0.21.2",
@ -3385,7 +3048,7 @@ checksum = "fc0c1bb43e5e8b8e05eb8009610344dbf285f06066c844032fbb3e546b3c71df"
dependencies = [
"tantivy-common",
"tantivy-fst",
"zstd 0.12.4",
"zstd",
]
[[package]]
@ -3407,17 +3070,6 @@ dependencies = [
"serde",
]
[[package]]
name = "tar"
version = "0.4.38"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b55807c0344e1e6c04d7c965f5289c39a8d94ae23ed5c0b57aabac549f871c6"
dependencies = [
"filetime",
"libc",
"xattr",
]
[[package]]
name = "temp_testdir"
version = "0.2.3"
@ -3538,42 +3190,6 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.13.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aea68938177975ab09da68552b720eac941779ff386baceaf77e0f5f9cea645f"
dependencies = [
"aho-corasick 0.7.20",
"cached-path",
"clap",
"derive_builder",
"dirs",
"esaxx-rs",
"getrandom",
"indicatif 0.15.0",
"itertools 0.9.0",
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"onig",
"paste",
"rand",
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.7.5",
"reqwest",
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
]
[[package]]
name = "tokio"
version = "1.28.2"
@ -4084,33 +3700,12 @@ dependencies = [
"tinyvec",
]
[[package]]
name = "unicode-normalization-alignments"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de"
dependencies = [
"smallvec",
]
[[package]]
name = "unicode-segmentation"
version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36"
[[package]]
name = "unicode-width"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b"
[[package]]
name = "unicode_categories"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "url"
version = "2.3.1"
@ -4590,42 +4185,16 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "xattr"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d1526bbe5aaeb5eb06885f4d987bcdfa5e23187055de9b83fe00156a821fabc"
dependencies = [
"libc",
]
[[package]]
name = "zip"
version = "0.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261"
dependencies = [
"aes",
"byteorder",
"bzip2",
"constant_time_eq",
"crc32fast",
"crossbeam-utils",
"flate2",
"hmac",
"pbkdf2",
"sha1",
"time 0.3.26",
"zstd 0.11.2+zstd.1.5.2",
]
[[package]]
name = "zstd"
version = "0.11.2+zstd.1.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4"
dependencies = [
"zstd-safe 5.0.2+zstd.1.5.2",
]
[[package]]
@ -4634,17 +4203,7 @@ version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c"
dependencies = [
"zstd-safe 6.0.6",
]
[[package]]
name = "zstd-safe"
version = "5.0.2+zstd.1.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db"
dependencies = [
"libc",
"zstd-sys",
"zstd-safe",
]
[[package]]


@ -6,7 +6,6 @@ members = [
"crates/tabby-scheduler",
"crates/tabby-download",
"crates/tabby-inference",
"crates/ctranslate2-bindings",
"crates/rust-cxx-cmake-bridge",
"crates/llama-cpp-bindings",
"crates/http-api-bindings",
@ -33,7 +32,6 @@ tantivy = "0.21.0"
async-trait = "0.1.72"
reqwest = { version = "0.11.18" }
derive_builder = "0.12.0"
tokenizers = "0.13.4-rc3"
futures = "0.3.28"
async-stream = "0.3.5"
regex = "1.10.0"


@ -1,3 +0,0 @@
/target
/Cargo.lock
/build


@ -1,16 +0,0 @@
cmake_minimum_required(VERSION 3.22)
project(ctranslate2_bindings)
add_subdirectory(CTranslate2)
add_library(dummy
src/dummy.cc
)
target_link_libraries(dummy
PRIVATE ctranslate2
)
include(cmake/export_libs.cmake)
export_all_target_libs(dummy)

@ -1 +0,0 @@
Subproject commit 8bcbeb6ff95b6906c9d5f7740fa9491431fa3e30


@ -1,25 +0,0 @@
[package]
name = "ctranslate2-bindings"
version = "0.5.0-dev"
edition = "2021"
[dependencies]
cxx = "1.0"
derive_builder = { workspace = true }
tokenizers = { workspace = true }
tokio = { workspace = true, features = ["rt"] }
tokio-util = { workspace = true }
tabby-inference = { path = "../tabby-inference" }
async-trait = { workspace = true }
futures.workspace = true
async-stream.workspace = true
[build-dependencies]
cxx-build = "1.0"
cmake = { version = "0.1", optional = true }
rust-cxx-cmake-bridge = { path = "../rust-cxx-cmake-bridge", optional = true }
[features]
default = ["dep:cmake", "dep:rust-cxx-cmake-bridge"]
link_shared = []
link_static_cuda = []


@ -1,74 +0,0 @@
use std::{env, path::PathBuf};
use cmake::Config;
use rust_cxx_cmake_bridge::read_cmake_generated;
fn main() {
// Tell cargo to invalidate the built crate whenever the wrapper changes
println!("cargo:rerun-if-changed=include/ctranslate2.h");
println!("cargo:rerun-if-changed=src/ctranslate2.cc");
println!("cargo:rerun-if-changed=src/lib.rs");
let mut lib = cxx_build::bridge("src/lib.rs");
lib.file("src/ctranslate2.cc")
.flag_if_supported("-std=c++17");
if cfg!(feature = "link_shared") {
let dir = env::var("CTRANSLATE2_ROOT").unwrap();
println!("cargo:rustc-link-search=native={}/lib", dir);
println!("cargo:rustc-link-lib=ctranslate2");
lib.flag_if_supported(&format!("-I{}/include", dir));
} else {
let dst = link_static();
lib.flag_if_supported(&format!("-I{}", dst.join("include").display()));
}
lib.compile("cxxbridge");
}
fn link_static() -> PathBuf {
let mut config = Config::new(".");
config
.define("CMAKE_BUILD_TYPE", "Release")
.define("BUILD_CLI", "OFF")
.define("CMAKE_INSTALL_RPATH_USE_LINK_PATH", "ON")
.define("BUILD_SHARED_LIBS", "OFF");
if cfg!(target_os = "linux") {
config
.define("WITH_MKL", "OFF")
.define("OPENMP_RUNTIME", "NONE");
if cfg!(target_feature = "sse4.1") {
config.cxxflag("-msse4.1");
}
if cfg!(feature = "link_static_cuda") {
config.define("WITH_CUDA", "ON").define("WITH_CUDNN", "ON");
if cfg!(target_arch = "aarch64") {
config.cxxflag("-mcpu=native");
}
} else {
config.define("WITH_OPENBLAS", "ON");
}
} else if cfg!(target_os = "macos") {
config
.define("CMAKE_OSX_ARCHITECTURES", "arm64")
.define("WITH_ACCELERATE", "ON")
.define("WITH_MKL", "OFF")
.define("OPENMP_RUNTIME", "NONE")
.define("WITH_RUY", "ON");
} else {
panic!("Invalid target")
};
let dst = config.build();
// Read static lib from generated deps.
let cmake_generated_libs_str =
std::fs::read_to_string(format!("/{}/build/cmake_generated_libs", dst.display())).unwrap();
read_cmake_generated(&cmake_generated_libs_str);
dst
}


@ -1,25 +0,0 @@
#! /bin/bash
set -e
set -x
UNAME="$(uname -s)"
case "${UNAME}" in
Linux*) MACHINE=linux;;
Darwin*) MACHINE=macos;;
*) exit 1;;
esac
rm -rf build
mkdir build && cd build
if [[ "$MACHINE" == "macos" ]]; then
CMAKE_EXTRA_OPTIONS='-DCMAKE_OSX_ARCHITECTURES=arm64 -DWITH_ACCELERATE=ON -DWITH_MKL=OFF -DOPENMP_RUNTIME=NONE -DWITH_RUY=ON'
elif [[ "$MACHINE" == "linux" ]]; then
CMAKE_EXTRA_OPTIONS='-DWITH_CUDA=ON -DWITH_CUDNN=ON -DWITH_MKL=ON -DWITH_DNNL=ON -DOPENMP_RUNTIME=COMP -DCUDA_NVCC_FLAGS=-Xfatbin=-compress-all -DCUDA_ARCH_LIST=Common -DCXXFLAGS=-msse4.1'
fi
cmake -DBULID_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release -DBUILD_CLI=OFF -DCMAKE_INSTALL_RPATH_USE_LINK_PATH=ON $CMAKE_EXTRA_OPTIONS ..
"$@"


@ -1,98 +0,0 @@
################################################################################
# WARNING: to list the system libraries(ie IMPORTED) you MUST set:
# set_target_properties(your_lib PROPERTIES IMPORTED_GLOBAL TRUE)
# just after the find_package call
# cf https://gitlab.kitware.com/cmake/cmake/-/issues/17256
#
# https://stackoverflow.com/questions/32756195/recursive-list-of-link-libraries-in-cmake
# https://stackoverflow.com/questions/32197663/how-can-i-remove-the-the-location-property-may-not-be-read-from-target-error-i
function(_get_link_libraries OUTPUT_LIST TARGET)
list(APPEND VISITED_TARGETS ${TARGET})
# DO NOT switch on IMPORTED or not
# An INTERFACE library CAN have LINK_LIBRARIES!
# get_target_property(IMPORTED ${TARGET} IMPORTED)
set(LIBS "")
get_target_property(LIBS_1 ${TARGET} INTERFACE_LINK_LIBRARIES)
get_target_property(LIBS_2 ${TARGET} LINK_LIBRARIES)
list(APPEND LIBS ${LIBS_1} ${LIBS_2})
set(LIB_FILES "")
foreach(LIB ${LIBS})
if (TARGET ${LIB})
list(FIND VISITED_TARGETS ${LIB} VISITED)
if (${VISITED} EQUAL -1)
# OLD: get_target_property(LIB_FILE ${LIB} LOCATION)
# NEW:
_get_link_libraries(LINK_LIB_FILES ${LIB})
set(LIB_FILE ${LIB})
list(APPEND LIB_FILES ${LINK_LIB_FILES})
list(APPEND LIB_FILES ${LIB_FILE})
endif()
elseif(EXISTS ${LIB})
set(LIB_FILE ${LIB})
list(APPEND LIB_FILES ${LIB_FILE})
endif()
endforeach()
set(VISITED_TARGETS ${VISITED_TARGETS} PARENT_SCOPE)
set(${OUTPUT_LIST} ${LIB_FILES} PARENT_SCOPE)
endfunction()
################################################################################
function(export_all_target_libs TARGET)
# NOTE: get_target_property(CIRCUIT_LIB_LINK_LIBRARIES a_target LINK_LIBRARIES) is NOT transitive
# This function will return eg: "$<TARGET_FILE:rust_cxx>;$<TARGET_FILE:circuit_lib>;"
# b/c generator expression are evaluated LATER
# cf https://stackoverflow.com/questions/59226127/cmake-generator-expression-how-to-get-target-file-property-on-list-of-targets
set(ALL_LINK_LIBRARIES "")
_get_link_libraries(ALL_LINK_LIBRARIES ${TARGET})
message(STATUS "ALL_LINK_LIBRARIES : ${ALL_LINK_LIBRARIES}")
set(ALL_LIBS "")
set(ALL_EXTERNAL_LIBS "")
# TODO move that back into get_link_libraries
# NOTE: we MUST do it in 2 steps:
# - collect all the LINK_LIBRARIES recursively
# - loop on those and get their TARGET_FILE (if not INTERFACE_LIBRARY)
# That is b/c in get_link_libraries a INTERFACE_LIBRARY CAN have link_libraries
# but we CAN NOT evaluate generator expressions at this time.
foreach(LIB ${ALL_LINK_LIBRARIES})
# MUST skip INTERFACE else:
# CMake Error at src/CMakeLists.txt:136 (add_custom_command):
# Error evaluating generator expression:
# $<TARGET_FILE:rust_cxx>
# Target "rust_cxx" is not an executable or library.
# SHARED_LIBRARY,INTERFACE_LIBRARY,STATIC_LIBRARY
#
if (TARGET ${LIB})
get_target_property(LIB_TYPE ${LIB} TYPE)
message(STATUS "LIB_TYPE : ${LIB} = ${LIB_TYPE}")
if(NOT ${LIB_TYPE} STREQUAL "INTERFACE_LIBRARY")
set(LIB_FILE $<TARGET_FILE:${LIB}>)
list(APPEND ALL_LIBS ${LIB_FILE})
endif()
elseif(EXISTS ${LIB})
set(LIB_FILE ${LIB})
message(STATUS "LIB_TYPE : ${LIB} = EXTERNAL")
list(APPEND ALL_LIBS ${LIB_FILE})
endif()
endforeach() # LIB ${ALL_LIBS}
message(STATUS "ALL_LIBS : ${ALL_LIBS}")
# add_custom_command(ie echoing only to stdout) works but more difficult to get from build.rs
# b/c when there is "ninja: no work to do" it will NOT echo on the console
add_custom_command(
TARGET ${TARGET}
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E echo ${ALL_LIBS} > ${CMAKE_CURRENT_BINARY_DIR}/cmake_generated_libs
# OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/cmake_generated_libs
VERBATIM
)
endfunction(export_all_target_libs)


@ -1,30 +0,0 @@
#pragma once
#include "rust/cxx.h"
#include <memory>
namespace tabby {
struct InferenceContext;
typedef rust::Fn<bool(InferenceContext&, size_t, uint32_t, rust::String)> InferenceCallback;
class TextInferenceEngine {
public:
virtual ~TextInferenceEngine();
virtual rust::Vec<uint32_t> inference(
rust::Box<InferenceContext> context,
InferenceCallback callback,
rust::Slice<const rust::String> tokens,
size_t max_decoding_length,
float sampling_temperature
) const = 0;
};
std::shared_ptr<TextInferenceEngine> create_engine(
rust::Str model_path,
rust::Str model_type,
rust::Str device,
rust::Slice<const int32_t> device_indices
);
} // namespace


@ -1,135 +0,0 @@
#include "ctranslate2-bindings/include/ctranslate2.h"
#include "ctranslate2/translator.h"
#include "ctranslate2/generator.h"
namespace tabby {
TextInferenceEngine::~TextInferenceEngine() {}
template <class Model, class Child>
class TextInferenceEngineImpl : public TextInferenceEngine {
protected:
struct Options {
size_t max_decoding_length;
float sampling_temperature;
};
public:
rust::Vec<uint32_t> inference(
rust::Box<InferenceContext> context,
InferenceCallback callback,
rust::Slice<const rust::String> tokens,
size_t max_decoding_length,
float sampling_temperature
) const {
// Inference.
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
return process(
std::move(context),
std::move(callback),
input_tokens,
Options{max_decoding_length, sampling_temperature}
);
}
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
auto impl = std::make_unique<Child>();
impl->model_ = std::make_unique<Model>(loader);
return impl;
}
protected:
virtual rust::Vec<uint32_t> process(
rust::Box<InferenceContext> context,
InferenceCallback callback,
const std::vector<std::string>& tokens,
const Options& options) const = 0;
std::unique_ptr<Model> model_;
};
class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
protected:
virtual rust::Vec<uint32_t> process(
rust::Box<InferenceContext> context,
InferenceCallback callback,
const std::vector<std::string>& tokens,
const Options& options) const override {
ctranslate2::TranslationOptions x;
x.max_decoding_length = options.max_decoding_length;
x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1;
rust::Vec<uint32_t> output_ids;
x.callback = [&](ctranslate2::GenerationStepResult result) {
bool stop = callback(*context, result.step, result.token_id, result.token);
if (!stop) {
output_ids.push_back(result.token_id);
} else if (result.is_last) {
output_ids.push_back(result.token_id);
}
return stop;
};
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
return output_ids;
}
};
class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
protected:
virtual rust::Vec<uint32_t> process(
rust::Box<InferenceContext> context,
InferenceCallback callback,
const std::vector<std::string>& tokens,
const Options& options) const override {
ctranslate2::GenerationOptions x;
x.include_prompt_in_result = false;
x.max_length = options.max_decoding_length;
x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1;
rust::Vec<uint32_t> output_ids;
x.callback = [&](ctranslate2::GenerationStepResult result) {
bool stop = callback(*context, result.step, result.token_id, result.token);
if (!stop) {
output_ids.push_back(result.token_id);
} else if (result.is_last) {
output_ids.push_back(result.token_id);
}
return stop;
};
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
return output_ids;
}
};
std::shared_ptr<TextInferenceEngine> create_engine(
rust::Str model_path,
rust::Str model_type,
rust::Str device,
rust::Slice<const int32_t> device_indices
) {
std::string model_type_str(model_type);
std::string model_path_str(model_path);
ctranslate2::models::ModelLoader loader(model_path_str);
loader.device = ctranslate2::str_to_device(std::string(device));
loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
loader.compute_type = ctranslate2::ComputeType::AUTO;
const size_t num_cpus = std::thread::hardware_concurrency();
if (loader.device == ctranslate2::Device::CUDA) {
// When device is cuda, set parallelism to be number of thread, capped to 4 to avoid VRAM oom.
loader.num_replicas_per_device = std::min<int32_t>(num_cpus, 4);
} else if (loader.device == ctranslate2::Device::CPU){
// When device is cpu, adjust the number based on threads per replica.
// https://github.com/OpenNMT/CTranslate2/blob/master/src/utils.cc#L77
loader.num_replicas_per_device = std::max<int32_t>(num_cpus / 4, 1);
}
if (model_type_str == "AutoModelForCausalLM") {
return DecoderImpl::create(loader);
} else if (model_type_str == "AutoModelForSeq2SeqLM") {
return EncoderDecoderImpl::create(loader);
} else {
return nullptr;
}
}
} // namespace tabby


@ -1,178 +0,0 @@
use std::sync::Arc;
use async_stream::stream;
use async_trait::async_trait;
use derive_builder::Builder;
use futures::stream::BoxStream;
use tabby_inference::{
decoding::{DecodingFactory, IncrementalDecoding},
helpers, TextGeneration, TextGenerationOptions,
};
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc::{channel, Sender};
use tokio_util::sync::CancellationToken;
#[cxx::bridge(namespace = "tabby")]
mod ffi {
extern "Rust" {
type InferenceContext;
}
unsafe extern "C++" {
include!("ctranslate2-bindings/include/ctranslate2.h");
type TextInferenceEngine;
fn create_engine(
model_path: &str,
model_type: &str,
device: &str,
device_indices: &[i32],
) -> SharedPtr<TextInferenceEngine>;
fn inference(
&self,
context: Box<InferenceContext>,
callback: fn(
&mut InferenceContext,
// step
usize,
// token_id
u32,
// token
String,
) -> bool,
tokens: &[String],
max_decoding_length: usize,
sampling_temperature: f32,
) -> Vec<u32>;
}
}
unsafe impl Send for ffi::TextInferenceEngine {}
unsafe impl Sync for ffi::TextInferenceEngine {}
#[derive(Builder, Debug)]
pub struct CTranslate2EngineOptions {
model_path: String,
model_type: String,
tokenizer_path: String,
device: String,
device_indices: Vec<i32>,
}
pub struct InferenceContext {
sender: Sender<String>,
decoding: IncrementalDecoding,
cancel: CancellationToken,
}
impl InferenceContext {
fn new(
sender: Sender<String>,
decoding: IncrementalDecoding,
cancel: CancellationToken,
) -> Self {
InferenceContext {
sender,
decoding,
cancel,
}
}
}
pub struct CTranslate2Engine {
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
decoding_factory: DecodingFactory,
tokenizer: Arc<Tokenizer>,
}
impl CTranslate2Engine {
pub fn create(options: CTranslate2EngineOptions) -> Self where {
let engine = ffi::create_engine(
&options.model_path,
&options.model_type,
&options.device,
&options.device_indices,
);
return Self {
engine,
decoding_factory: DecodingFactory::default(),
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
};
}
}
#[async_trait]
impl TextGeneration for CTranslate2Engine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let s = self.generate_stream(prompt, options).await;
helpers::stream_to_string(s).await
}
async fn generate_stream(
&self,
prompt: &str,
options: TextGenerationOptions,
) -> BoxStream<String> {
let encoding = self.tokenizer.encode(prompt, true).unwrap();
let decoding = self.decoding_factory.create_incremental_decoding(
self.tokenizer.clone(),
truncate_tokens(encoding.get_ids(), options.max_input_length),
options.language,
);
let cancel = CancellationToken::new();
let engine = self.engine.clone();
let (sender, mut receiver) = channel::<String>(8);
let context = InferenceContext::new(sender, decoding, cancel.clone());
tokio::task::spawn_blocking(move || {
let context = Box::new(context);
engine.inference(
context,
inference_callback,
truncate_tokens(encoding.get_tokens(), options.max_input_length),
options.max_decoding_length,
options.sampling_temperature,
);
});
let s = stream! {
let _guard = cancel.drop_guard();
while let Some(text) = receiver.recv().await {
yield text;
}
};
Box::pin(s)
}
}
fn truncate_tokens<T>(tokens: &[T], max_length: usize) -> &[T] {
if max_length < tokens.len() {
let start = tokens.len() - max_length;
&tokens[start..]
} else {
tokens
}
}
fn inference_callback(
context: &mut InferenceContext,
_step: usize,
token_id: u32,
_token: String,
) -> bool {
if context.cancel.is_cancelled() {
true
} else if let Some(new_text) = context.decoding.next_token(token_id) {
let _ = context.sender.blocking_send(new_text);
false
} else {
true
}
}


@ -16,7 +16,6 @@ async-trait = { workspace = true }
tokio = { workspace = true, features = ["rt"] }
tabby-inference = { path = "../tabby-inference" }
derive_builder = { workspace = true }
tokenizers = { workspace = true }
tokio-util = { workspace = true }
futures.workspace = true
async-stream.workspace = true


@ -4,16 +4,15 @@
#include <memory>
namespace llama {
struct StepOutput;
class TextInferenceEngine {
public:
virtual ~TextInferenceEngine();
virtual void add_request(uint32_t request_id, rust::Slice<const uint32_t> input_token_ids) = 0;
virtual void add_request(uint32_t request_id, rust::Str text, size_t max_input_length) = 0;
virtual void stop_request(uint32_t request_id) = 0;
virtual rust::Vec<uint32_t> step() = 0;
virtual uint32_t eos_token_id() const = 0;
virtual rust::Vec<StepOutput> step() = 0;
};
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path);


@ -8,6 +8,8 @@
#include <ggml.h>
#include <llama.h>
#include "llama-cpp-bindings/src/lib.rs.h"
namespace llama {
TextInferenceEngine::~TextInferenceEngine() {}
@ -27,20 +29,56 @@ constexpr size_t N_BATCH = 512; // # per batch inference.
constexpr size_t N_CTX = 4096; // # max kv history.
struct Request {
Request(size_t request_id, rust::Slice<const uint32_t> input_token_ids) :
Request(size_t request_id, std::vector<llama_token> input_token_ids) :
id(request_id),
tokens(input_token_ids.begin(), input_token_ids.end()) {
}
size_t id = -1;
uint32_t id = -1;
llama_seq_id seq_id = -1;
std::vector<llama_token> tokens;
size_t i_batch = -1;
size_t n_past = 0;
int32_t multibyte_pending = 0;
std::string generated_text;
};
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
}
return std::string(result.data(), result.size());
}
std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const rust::Str & text,
bool add_bos,
bool special) {
// upper limit for the number of tokens
int n_tokens = text.length() + add_bos;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
}
return result;
}
template<class T>
using owned = std::unique_ptr<T, std::function<void(T*)>>;
@ -56,15 +94,20 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
llama_batch_free(batch_);
}
void add_request(uint32_t request_id, rust::Slice<const uint32_t> input_token_ids) override {
pending_requests_.push_back(Request(request_id, input_token_ids));
virtual void add_request(uint32_t request_id, rust::Str text, size_t max_input_length) override {
auto tokens = llama_tokenize(llama_get_model(ctx_.get()), text, false, true);
if (tokens.size() > max_input_length) {
int start = tokens.size() - max_input_length;
tokens = std::vector<llama_token>(tokens.begin() + start, tokens.end());
}
pending_requests_.push_back(Request(request_id, tokens));
}
void stop_request(uint32_t request_id) override {
stopped_requests_.insert(request_id);
}
rust::Vec<uint32_t> step() override {
rust::Vec<StepOutput> step() override {
auto* ctx = ctx_.get();
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
@ -123,28 +166,29 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
request.i_batch = batch_.n_tokens - 1;
}
rust::Vec<uint32_t> result;
result.reserve(requests_.size() * 2);
rust::Vec<StepOutput> result;
result.reserve(requests_.size());
// Decode tokens in chunks
for (size_t i = 0; i < static_cast<size_t>(batch_.n_tokens); i += N_BATCH) {
const int32_t n_tokens = std::min(N_BATCH, batch_.n_tokens - i);
llama_batch batch_view = {
n_tokens,
batch_.token + i,
nullptr,
batch_.pos + i,
batch_.n_seq_id + i,
batch_.seq_id + i,
batch_.logits + i,
0, 0, 0, // unused
};
llama_batch batch_view = {
n_tokens,
batch_.token + i,
nullptr,
batch_.pos + i,
batch_.n_seq_id + i,
batch_.seq_id + i,
batch_.logits + i,
0, 0, 0, // unused
};
const int ret = llama_decode(ctx, batch_view);
const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
throw std::runtime_error("Failed to eval");
}
const auto eos_id = llama_token_eos(llama_get_model(ctx));
for (auto& request : requests_) {
if ((request.i_batch < i) || (request.i_batch >= (i + n_tokens))) {
continue;
@ -159,18 +203,44 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
request.tokens.clear();
request.tokens.push_back(next_token);
result.push_back(request.id);
result.push_back(next_token);
const auto token_str = llama_token_to_piece(ctx, next_token);
request.generated_text += token_str;
// FIXME: Hack for codellama to simplify tabby's implementation.
const bool is_eos = next_token == eos_id || token_str == " <EOT>";
if (request.multibyte_pending > 0) {
request.multibyte_pending -= token_str.size();
} else if (token_str.size() == 1) {
const char c = token_str[0];
// 2-byte characters: 110xxxxx 10xxxxxx
if ((c & 0xE0) == 0xC0) {
request.multibyte_pending = 1;
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
}
else if ((c & 0xF0) == 0xE0) {
request.multibyte_pending = 2;
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
} else if ((c & 0xF8) == 0xF0) {
request.multibyte_pending = 3;
}
else {
request.multibyte_pending = 0;
}
}
if (request.multibyte_pending == 0) {
rust::String generated_text = is_eos ? "" : request.generated_text;
result.push_back({request.id, generated_text});
request.generated_text.clear();
}
}
}
return result;
}
uint32_t eos_token_id() const override {
return llama_token_eos(llama_get_model(ctx_.get()));
}
private:
owned<llama_model> model_;
owned<llama_context> ctx_;
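
The multibyte_pending bookkeeping above exists because a single sampled token can end in the middle of a UTF-8 sequence; the engine counts the continuation bytes still owed and withholds the buffered text until the character is complete. A minimal Rust rendering of the same lead-byte arithmetic (an illustrative sketch, not code from this commit):

// Number of continuation bytes that must follow a given UTF-8 lead byte.
fn continuation_bytes(lead: u8) -> usize {
    if (lead & 0xE0) == 0xC0 {
        1 // 110xxxxx: 2-byte character
    } else if (lead & 0xF0) == 0xE0 {
        2 // 1110xxxx: 3-byte character
    } else if (lead & 0xF8) == 0xF0 {
        3 // 11110xxx: 4-byte character
    } else {
        0 // ASCII (0xxxxxxx) or a continuation byte (10xxxxxx)
    }
}

fn main() {
    assert_eq!(continuation_bytes("é".as_bytes()[0]), 1); // é = C3 A9
    assert_eq!(continuation_bytes("€".as_bytes()[0]), 2); // € = E2 82 AC
    assert_eq!(continuation_bytes(b'a'), 0);
}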


@ -7,10 +7,9 @@ use derive_builder::Builder;
use ffi::create_engine;
use futures::{lock::Mutex, stream::BoxStream};
use tabby_inference::{
decoding::{DecodingFactory, IncrementalDecoding},
decoding::{StopCondition, StopConditionFactory},
helpers, TextGeneration, TextGenerationOptions,
};
use tokenizers::tokenizer::Tokenizer;
use tokio::{
sync::mpsc::{channel, Sender},
task::yield_now,
@ -18,6 +17,11 @@ use tokio::{
#[cxx::bridge(namespace = "llama")]
mod ffi {
struct StepOutput {
request_id: u32,
text: String,
}
unsafe extern "C++" {
include!("llama-cpp-bindings/include/engine.h");
@ -28,12 +32,11 @@ mod ffi {
fn add_request(
self: Pin<&mut TextInferenceEngine>,
request_id: u32,
input_token_ids: &[u32],
prompt: &str,
max_input_length: usize,
);
fn stop_request(self: Pin<&mut TextInferenceEngine>, request_id: u32);
fn step(self: Pin<&mut TextInferenceEngine>) -> Result<Vec<u32>>;
fn eos_token_id(&self) -> u32;
fn step(self: Pin<&mut TextInferenceEngine>) -> Result<Vec<StepOutput>>;
}
}
@ -42,26 +45,22 @@ unsafe impl Sync for ffi::TextInferenceEngine {}
struct InferenceRequest {
tx: Sender<String>,
decoding: IncrementalDecoding,
stop_condition: StopCondition,
}
struct AsyncTextInferenceEngine {
engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
tokenizer: Arc<Tokenizer>,
decoding_factory: DecodingFactory,
stop_condition_factory: StopConditionFactory,
requests: Mutex<HashMap<u32, InferenceRequest>>,
next_request_id: Mutex<u32>,
eos_token_id: u32,
}
impl AsyncTextInferenceEngine {
fn create(engine: UniquePtr<ffi::TextInferenceEngine>, tokenizer: Tokenizer) -> Self {
fn create(engine: UniquePtr<ffi::TextInferenceEngine>) -> Self {
Self {
eos_token_id: engine.eos_token_id(),
engine: Mutex::new(engine),
tokenizer: Arc::new(tokenizer),
decoding_factory: DecodingFactory::default(),
stop_condition_factory: StopConditionFactory::default(),
requests: Mutex::new(HashMap::new()),
next_request_id: Mutex::new(0),
}
@ -79,18 +78,15 @@ impl AsyncTextInferenceEngine {
panic!("Failed to evaluation");
};
for i in (0..result.len()).step_by(2) {
let request_id = result[i];
let token_id = result[i + 1];
let InferenceRequest { tx, decoding } = requests.get_mut(&request_id).unwrap();
for ffi::StepOutput { request_id, text } in result {
let mut stopped = false;
let InferenceRequest { tx, stop_condition } = requests.get_mut(&request_id).unwrap();
if tx.is_closed() || token_id == self.eos_token_id {
if tx.is_closed() || text.is_empty() {
// Cancelled by client side or hit eos.
stopped = true;
} else if let Some(new_text) = decoding.next_token(token_id) {
match tx.send(new_text).await {
} else if !stop_condition.should_stop(&text) {
match tx.send(text).await {
Ok(_) => (),
Err(_) => stopped = true,
}
@ -111,25 +107,21 @@ impl AsyncTextInferenceEngine {
prompt: &str,
options: TextGenerationOptions,
) -> BoxStream<String> {
let encoding = self.tokenizer.encode(prompt, true).unwrap();
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
let decoding = self.decoding_factory.create_incremental_decoding(
self.tokenizer.clone(),
input_token_ids,
options.language,
);
let stop_condition = self.stop_condition_factory.create(prompt, options.language);
let (tx, mut rx) = channel::<String>(4);
{
let mut engine = self.engine.lock().await;
let engine = engine.as_mut().unwrap();
let mut request_id = self.next_request_id.lock().await;
self.requests
.lock()
.await
.insert(*request_id, InferenceRequest { tx, decoding });
engine.add_request(*request_id, input_token_ids);
.insert(*request_id, InferenceRequest { tx, stop_condition });
engine
.as_mut()
.unwrap()
.add_request(*request_id, prompt, options.max_input_length);
// 2048 should be large enough to avoid collision.
*request_id = (*request_id + 1) % 2048;
@ -155,7 +147,6 @@ impl AsyncTextInferenceEngine {
#[derive(Builder, Debug)]
pub struct LlamaTextGenerationOptions {
model_path: String,
tokenizer_path: String,
use_gpu: bool,
}
@ -169,9 +160,8 @@ impl LlamaTextGeneration {
if engine.is_null() {
panic!("Unable to load model: {}", options.model_path);
}
let tokenizer = Tokenizer::from_file(&options.tokenizer_path).unwrap();
let ret = LlamaTextGeneration {
engine: Arc::new(AsyncTextInferenceEngine::create(engine, tokenizer)),
engine: Arc::new(AsyncTextInferenceEngine::create(engine)),
};
ret.start_background_job();
ret
@ -203,12 +193,3 @@ impl TextGeneration for LlamaTextGeneration {
self.engine.generate_stream(prompt, options).await
}
}
fn truncate_tokens(tokens: &[u32], max_length: usize) -> &[u32] {
if max_length < tokens.len() {
let start = tokens.len() - max_length;
&tokens[start..]
} else {
tokens
}
}


@ -78,14 +78,6 @@ impl ModelDir {
self.path_string("tabby.json")
}
pub fn tokenizer_file(&self) -> String {
self.path_string("tokenizer.json")
}
pub fn ctranslate2_dir(&self) -> String {
self.path_string("ctranslate2")
}
pub fn ggml_q8_0_file(&self) -> String {
self.path_string("ggml/q8_0.gguf")
}


@ -29,27 +29,8 @@ impl Downloader {
}
}
pub async fn download_ctranslate2_files(&self) -> Result<()> {
let files = vec![
("tabby.json", true),
("tokenizer.json", true),
("ctranslate2/vocabulary.txt", false),
("ctranslate2/shared_vocabulary.txt", false),
("ctranslate2/vocabulary.json", false),
("ctranslate2/shared_vocabulary.json", false),
("ctranslate2/config.json", true),
("ctranslate2/model.bin", true),
];
self.download_files(&files).await
}
pub async fn download_ggml_files(&self) -> Result<()> {
let files = vec![
("tabby.json", true),
("tokenizer.json", true),
("ggml/q8_0.v2.gguf", true),
];
let files = vec![("tabby.json", true), ("ggml/q8_0.v2.gguf", true)];
self.download_files(&files).await
}


@ -12,5 +12,4 @@ dashmap = "5.5.3"
derive_builder = "0.12.0"
futures = { workspace = true }
regex.workspace = true
tokenizers.workspace = true
tabby-common = { path = "../tabby-common" }


@ -1,11 +1,8 @@
use std::sync::Arc;
use dashmap::DashMap;
use regex::Regex;
use tabby_common::languages::Language;
use tokenizers::tokenizer::Tokenizer;
pub struct DecodingFactory {
pub struct StopConditionFactory {
stop_regex_cache: DashMap<String, Regex>,
}
@ -16,7 +13,7 @@ where
s.into().chars().rev().collect()
}
impl Default for DecodingFactory {
impl Default for StopConditionFactory {
fn default() -> Self {
Self {
stop_regex_cache: DashMap::new(),
@ -24,14 +21,9 @@ impl Default for DecodingFactory {
}
}
impl DecodingFactory {
pub fn create_incremental_decoding(
&self,
tokenizer: Arc<Tokenizer>,
input_token_ids: &[u32],
language: &'static Language,
) -> IncrementalDecoding {
IncrementalDecoding::new(tokenizer, self.get_re(language), input_token_ids)
impl StopConditionFactory {
pub fn create(&self, text: &str, language: &'static Language) -> StopCondition {
StopCondition::new(self.get_re(language), text)
}
fn get_re(&self, language: &'static Language) -> Option<Regex> {
@ -62,68 +54,31 @@ fn create_stop_regex(stop_words: Vec<String>) -> Regex {
Regex::new(&regex_string).expect("Failed to create regex")
}
pub struct IncrementalDecoding {
tokenizer: Arc<Tokenizer>,
pub struct StopCondition {
stop_re: Option<Regex>,
token_ids: Vec<u32>,
prefix_offset: usize,
read_offset: usize,
reversed_text: String,
}
impl IncrementalDecoding {
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
let text = tokenizer
.decode(input_token_ids, /* skip_special_token = */ true)
.expect("Cannot decode token from tokenizer.");
impl StopCondition {
pub fn new(stop_re: Option<Regex>, text: &str) -> Self {
Self {
tokenizer,
stop_re,
token_ids: input_token_ids.to_owned(),
prefix_offset: 0,
read_offset: input_token_ids.len(),
reversed_text: reverse(text),
}
}
pub fn next_token(&mut self, token_id: u32) -> Option<String> {
let skip_special_token = true;
self.token_ids.push(token_id);
let prefix_text = self
.tokenizer
.decode(
&self.token_ids[self.prefix_offset..self.read_offset],
skip_special_token,
)
.expect("Cannot decode token from tokenizer.");
let new_text = self
.tokenizer
.decode(&self.token_ids[self.prefix_offset..], skip_special_token)
.expect("Cannot decode token from tokenizer.");
let new_text = if new_text.len() > prefix_text.len() && !new_text.ends_with('�') {
self.prefix_offset = self.read_offset;
self.read_offset = self.token_ids.len();
&new_text[prefix_text.len()..]
} else {
""
};
pub fn should_stop(&mut self, new_text: &str) -> bool {
if !new_text.is_empty() {
self.reversed_text = reverse(new_text) + &self.reversed_text;
if let Some(re) = &self.stop_re {
if re.is_match(&self.reversed_text) {
return None;
return true;
}
}
}
Some(new_text.to_owned())
false
}
}
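
The rewrite drops the tokenizer from stop-word handling entirely: StopCondition consumes the text chunks produced by the C++ side and keeps the generated text reversed, so a stop word that arrives split across chunks can still be caught by matching a reversed pattern near the start of the buffer. A hedged usage sketch (assuming the tabby_inference and regex crates are on the dependency list; the exact pattern built by create_stop_regex is elided above, so a hand-built reversed, anchored regex stands in for it):

use regex::Regex;
use tabby_inference::decoding::StopCondition;

fn reverse(s: &str) -> String {
    s.chars().rev().collect()
}

fn main() {
    // Stand-in for create_stop_regex(vec!["\n\n".to_owned()]): the reversed
    // stop word, anchored at the start of the reversed buffer.
    let stop_re = Regex::new(&format!("^{}", regex::escape(&reverse("\n\n")))).unwrap();
    let mut cond = StopCondition::new(Some(stop_re), "fn main() {");

    assert!(!cond.should_stop("\n    println!(\"hi\");"));
    assert!(cond.should_stop("\n}\n\n")); // stop word reached; caller halts the request
}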


@ -41,7 +41,6 @@ pub struct EngineInfo {
fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box<dyn TextGeneration> {
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
.model_path(model_dir.ggml_q8_0_v2_file())
.tokenizer_path(model_dir.tokenizer_file())
.use_gpu(device.ggml_use_gpu())
.build()
.unwrap();