mirror of
https://github.com/TabbyML/tabby
synced 2024-11-22 00:08:06 +00:00
refactor: use llama.cpp tokenizer (#683)
* refactor: switch to llama.cpp tokenizer to simplify implementation * refactor: remove tokenizer dependency from tabby * refactor: renaming decoding to stop condition * refactor: remove tokenizer dependency * refactor: remove submodule * chore: update formatting * move tokenization to c++
This commit is contained in:
parent
f15926f233
commit
296342efd8
3
.gitmodules
vendored
3
.gitmodules
vendored
@ -1,6 +1,3 @@
|
||||
[submodule "crates/ctranslate2-bindings/CTranslate2"]
|
||||
path = crates/ctranslate2-bindings/CTranslate2
|
||||
url = https://github.com/OpenNMT/CTranslate2.git
|
||||
[submodule "crates/llama-cpp-bindings/llama.cpp"]
|
||||
path = crates/llama-cpp-bindings/llama.cpp
|
||||
url = https://github.com/TabbyML/llama.cpp
|
||||
|
@ -10,6 +10,7 @@
|
||||
* Switch cpu backend to llama.cpp: https://github.com/TabbyML/tabby/pull/638
|
||||
* Add `server.completion_timeout` to control the code completion interface timeout: https://github.com/TabbyML/tabby/pull/637
|
||||
* Switch cuda backend to llama.cpp: https://github.com/TabbyML/tabby/pull/656
|
||||
* Switch tokenizer to llama.cpp, so tabby no longer needs to download an additional tokenizer file: https://github.com/TabbyML/tabby/pull/683
|
||||
|
||||
# v0.4.0
|
||||
|
||||
|
457
Cargo.lock
generated
457
Cargo.lock
generated
@ -17,17 +17,6 @@ version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
||||
|
||||
[[package]]
|
||||
name = "aes"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "433cfd6710c9986c576a25ca913c39d66a6474107b406f34f91d4a8923395241"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cipher",
|
||||
"cpufeatures",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.8.3"
|
||||
@ -39,15 +28,6 @@ dependencies = [
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "0.7.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "1.1.2"
|
||||
@ -309,12 +289,6 @@ version = "0.21.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d"
|
||||
|
||||
[[package]]
|
||||
name = "base64ct"
|
||||
version = "1.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
@ -373,27 +347,6 @@ version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be"
|
||||
|
||||
[[package]]
|
||||
name = "bzip2"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8"
|
||||
dependencies = [
|
||||
"bzip2-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bzip2-sys"
|
||||
version = "0.1.11+1.0.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cached"
|
||||
version = "0.46.0"
|
||||
@ -412,28 +365,6 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cached-path"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "097968e38f1319207f057d0f4d76452e4f4f847a5de61c5215379f297fa034f3"
|
||||
dependencies = [
|
||||
"flate2",
|
||||
"fs2",
|
||||
"glob",
|
||||
"indicatif 0.16.2",
|
||||
"log",
|
||||
"rand",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"tar",
|
||||
"tempfile",
|
||||
"thiserror",
|
||||
"zip",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cached_proc_macro"
|
||||
version = "0.18.0"
|
||||
@ -494,16 +425,6 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cipher"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
|
||||
dependencies = [
|
||||
"crypto-common",
|
||||
"inout",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.3.0"
|
||||
@ -584,12 +505,6 @@ dependencies = [
|
||||
"windows-sys 0.45.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "constant_time_eq"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc"
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.3"
|
||||
@ -694,24 +609,6 @@ dependencies = [
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ctranslate2-bindings"
|
||||
version = "0.5.0-dev"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
"cmake",
|
||||
"cxx",
|
||||
"cxx-build",
|
||||
"derive_builder",
|
||||
"futures",
|
||||
"rust-cxx-cmake-bridge",
|
||||
"tabby-inference",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cxx"
|
||||
version = "1.0.95"
|
||||
@ -887,7 +784,6 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
|
||||
dependencies = [
|
||||
"block-buffer",
|
||||
"crypto-common",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -967,15 +863,6 @@ dependencies = [
|
||||
"backtrace",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "esaxx-rs"
|
||||
version = "0.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1f748b253ceca9fed5f42f8b5ceb3851e93102199bc25b64b65369f76e5c0a35"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastdivide"
|
||||
version = "0.4.0"
|
||||
@ -1010,18 +897,6 @@ dependencies = [
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "filetime"
|
||||
version = "0.2.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5cbc844cecaee9d4443931972e1289c8ff485cb4cc2767cb03ca139ed6885153"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"redox_syscall 0.2.16",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fixedbitset"
|
||||
version = "0.4.2"
|
||||
@ -1068,16 +943,6 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fs2"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fs4"
|
||||
version = "0.6.6"
|
||||
@ -1217,19 +1082,13 @@ version = "0.27.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4"
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
|
||||
|
||||
[[package]]
|
||||
name = "globset"
|
||||
version = "0.4.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "759c97c1e17c55525b57192c06a267cda0ac5210b222d6b82189a2338fa1c13d"
|
||||
dependencies = [
|
||||
"aho-corasick 1.1.2",
|
||||
"aho-corasick",
|
||||
"bstr",
|
||||
"fnv",
|
||||
"log",
|
||||
@ -1292,15 +1151,6 @@ version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
|
||||
|
||||
[[package]]
|
||||
name = "hmac"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
|
||||
dependencies = [
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "htmlescape"
|
||||
version = "0.3.1"
|
||||
@ -1476,30 +1326,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indicatif"
|
||||
version = "0.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7baab56125e25686df467fe470785512329883aab42696d661247aca2a2896e4"
|
||||
dependencies = [
|
||||
"console",
|
||||
"lazy_static",
|
||||
"number_prefix 0.3.0",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indicatif"
|
||||
version = "0.16.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d207dc617c7a380ab07ff572a6e52fa202a2a8f355860ac9c38e23f8196be1b"
|
||||
dependencies = [
|
||||
"console",
|
||||
"lazy_static",
|
||||
"number_prefix 0.4.0",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indicatif"
|
||||
version = "0.17.3"
|
||||
@ -1507,20 +1333,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cef509aa9bc73864d6756f0d34d35504af3cf0844373afe9b8669a5b8005a729"
|
||||
dependencies = [
|
||||
"console",
|
||||
"number_prefix 0.4.0",
|
||||
"number_prefix",
|
||||
"portable-atomic 0.3.20",
|
||||
"unicode-width",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "inout"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "instant"
|
||||
version = "0.1.12"
|
||||
@ -1562,24 +1379,6 @@ dependencies = [
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f56a2d0bc861f9165be4eb3442afd3c236d8a98afd426f65d92324ae1091a484"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.10.5"
|
||||
@ -1694,7 +1493,6 @@ dependencies = [
|
||||
"derive_builder",
|
||||
"futures",
|
||||
"tabby-inference",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
]
|
||||
@ -1747,22 +1545,6 @@ version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ea9b256699eda7b0387ffbc776dd625e28bde3918446381781245b7a50349d8"
|
||||
|
||||
[[package]]
|
||||
name = "macro_rules_attribute"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf0c9b980bf4f3a37fd7b1c066941dd1b1d0152ce6ee6e8fe8c49b9f6810d862"
|
||||
dependencies = [
|
||||
"macro_rules_attribute-proc_macro",
|
||||
"paste",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "macro_rules_attribute-proc_macro"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
|
||||
|
||||
[[package]]
|
||||
name = "matchers"
|
||||
version = "0.0.1"
|
||||
@ -1890,27 +1672,6 @@ dependencies = [
|
||||
"windows-sys 0.45.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "monostate"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0230b703f1ac35df1e24f6d0d2255472bcccaf657ecdfa4f1fcbcad1ad5bb98a"
|
||||
dependencies = [
|
||||
"monostate-impl",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "monostate-impl"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "multimap"
|
||||
version = "0.8.3"
|
||||
@ -2007,12 +1768,6 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "number_prefix"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a"
|
||||
|
||||
[[package]]
|
||||
name = "number_prefix"
|
||||
version = "0.4.0"
|
||||
@ -2066,28 +1821,6 @@ dependencies = [
|
||||
"loom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "onig"
|
||||
version = "6.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8c4b31c8722ad9171c6d77d3557db078cab2bd50afcc9d09c8b315c59df8ca4f"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"onig_sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "onig_sys"
|
||||
version = "69.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b829e3d7e9cc74c7e315ee8edb185bf4190da5acde74afd7fc59c35b1f086e7"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openssl"
|
||||
version = "0.10.52"
|
||||
@ -2250,35 +1983,12 @@ dependencies = [
|
||||
"windows-targets 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "password-hash"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700"
|
||||
dependencies = [
|
||||
"base64ct",
|
||||
"rand_core",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79"
|
||||
|
||||
[[package]]
|
||||
name = "pbkdf2"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917"
|
||||
dependencies = [
|
||||
"digest",
|
||||
"hmac",
|
||||
"password-hash",
|
||||
"sha2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.2.0"
|
||||
@ -2500,17 +2210,6 @@ dependencies = [
|
||||
"rayon-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-cond"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fd1259362c9065e5ea39a789ef40b1e3fd934c94beb7b5ab3ac6629d3b5e7cb7"
|
||||
dependencies = [
|
||||
"either",
|
||||
"itertools 0.8.2",
|
||||
"rayon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-core"
|
||||
version = "1.11.0"
|
||||
@ -2558,7 +2257,7 @@ version = "1.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d119d7c7ca818f8a53c300863d4f87566aac09943aef5b355bb83969dae75d87"
|
||||
dependencies = [
|
||||
"aho-corasick 1.1.2",
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-automata 0.4.1",
|
||||
"regex-syntax 0.8.1",
|
||||
@ -2579,7 +2278,7 @@ version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "465c6fc0621e4abc4187a2bda0937bfd4f722c2730b29562e19689ea796c9a4b"
|
||||
dependencies = [
|
||||
"aho-corasick 1.1.2",
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-syntax 0.8.1",
|
||||
]
|
||||
@ -2590,12 +2289,6 @@ version = "0.6.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.7.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da"
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.8.1"
|
||||
@ -2946,17 +2639,6 @@ dependencies = [
|
||||
"trackable",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1"
|
||||
version = "0.10.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.10.6"
|
||||
@ -3029,18 +2711,6 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spm_precompiled"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326"
|
||||
dependencies = [
|
||||
"base64 0.13.1",
|
||||
"nom 7.1.3",
|
||||
"serde",
|
||||
"unicode-segmentation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "stable_deref_trait"
|
||||
version = "1.2.0"
|
||||
@ -3093,12 +2763,6 @@ dependencies = [
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "subtle"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.109"
|
||||
@ -3215,7 +2879,7 @@ dependencies = [
|
||||
"async-trait",
|
||||
"cached",
|
||||
"futures-util",
|
||||
"indicatif 0.17.3",
|
||||
"indicatif",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@ -3237,7 +2901,6 @@ dependencies = [
|
||||
"futures",
|
||||
"regex",
|
||||
"tabby-common",
|
||||
"tokenizers",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -3273,7 +2936,7 @@ version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1d4675fed6fe2218ce11445374e181e864a8ffd0f28e7e0591ccfc38cd000ae"
|
||||
dependencies = [
|
||||
"aho-corasick 1.1.2",
|
||||
"aho-corasick",
|
||||
"arc-swap",
|
||||
"async-trait",
|
||||
"base64 0.21.2",
|
||||
@ -3385,7 +3048,7 @@ checksum = "fc0c1bb43e5e8b8e05eb8009610344dbf285f06066c844032fbb3e546b3c71df"
|
||||
dependencies = [
|
||||
"tantivy-common",
|
||||
"tantivy-fst",
|
||||
"zstd 0.12.4",
|
||||
"zstd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -3407,17 +3070,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tar"
|
||||
version = "0.4.38"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4b55807c0344e1e6c04d7c965f5289c39a8d94ae23ed5c0b57aabac549f871c6"
|
||||
dependencies = [
|
||||
"filetime",
|
||||
"libc",
|
||||
"xattr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "temp_testdir"
|
||||
version = "0.2.3"
|
||||
@ -3538,42 +3190,6 @@ version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
||||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.13.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aea68938177975ab09da68552b720eac941779ff386baceaf77e0f5f9cea645f"
|
||||
dependencies = [
|
||||
"aho-corasick 0.7.20",
|
||||
"cached-path",
|
||||
"clap",
|
||||
"derive_builder",
|
||||
"dirs",
|
||||
"esaxx-rs",
|
||||
"getrandom",
|
||||
"indicatif 0.15.0",
|
||||
"itertools 0.9.0",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"macro_rules_attribute",
|
||||
"monostate",
|
||||
"onig",
|
||||
"paste",
|
||||
"rand",
|
||||
"rayon",
|
||||
"rayon-cond",
|
||||
"regex",
|
||||
"regex-syntax 0.7.5",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"spm_precompiled",
|
||||
"thiserror",
|
||||
"unicode-normalization-alignments",
|
||||
"unicode-segmentation",
|
||||
"unicode_categories",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.28.2"
|
||||
@ -4084,33 +3700,12 @@ dependencies = [
|
||||
"tinyvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-normalization-alignments"
|
||||
version = "0.1.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de"
|
||||
dependencies = [
|
||||
"smallvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-segmentation"
|
||||
version = "1.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b"
|
||||
|
||||
[[package]]
|
||||
name = "unicode_categories"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
|
||||
|
||||
[[package]]
|
||||
name = "url"
|
||||
version = "2.3.1"
|
||||
@ -4590,42 +4185,16 @@ dependencies = [
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xattr"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d1526bbe5aaeb5eb06885f4d987bcdfa5e23187055de9b83fe00156a821fabc"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zip"
|
||||
version = "0.6.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261"
|
||||
dependencies = [
|
||||
"aes",
|
||||
"byteorder",
|
||||
"bzip2",
|
||||
"constant_time_eq",
|
||||
"crc32fast",
|
||||
"crossbeam-utils",
|
||||
"flate2",
|
||||
"hmac",
|
||||
"pbkdf2",
|
||||
"sha1",
|
||||
"time 0.3.26",
|
||||
"zstd 0.11.2+zstd.1.5.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstd"
|
||||
version = "0.11.2+zstd.1.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4"
|
||||
dependencies = [
|
||||
"zstd-safe 5.0.2+zstd.1.5.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -4634,17 +4203,7 @@ version = "0.12.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c"
|
||||
dependencies = [
|
||||
"zstd-safe 6.0.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstd-safe"
|
||||
version = "5.0.2+zstd.1.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"zstd-sys",
|
||||
"zstd-safe",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -6,7 +6,6 @@ members = [
|
||||
"crates/tabby-scheduler",
|
||||
"crates/tabby-download",
|
||||
"crates/tabby-inference",
|
||||
"crates/ctranslate2-bindings",
|
||||
"crates/rust-cxx-cmake-bridge",
|
||||
"crates/llama-cpp-bindings",
|
||||
"crates/http-api-bindings",
|
||||
@ -33,7 +32,6 @@ tantivy = "0.21.0"
|
||||
async-trait = "0.1.72"
|
||||
reqwest = { version = "0.11.18" }
|
||||
derive_builder = "0.12.0"
|
||||
tokenizers = "0.13.4-rc3"
|
||||
futures = "0.3.28"
|
||||
async-stream = "0.3.5"
|
||||
regex = "1.10.0"
|
||||
|
3
crates/ctranslate2-bindings/.gitignore
vendored
3
crates/ctranslate2-bindings/.gitignore
vendored
@ -1,3 +0,0 @@
|
||||
/target
|
||||
/Cargo.lock
|
||||
/build
|
@ -1,16 +0,0 @@
|
||||
# Top-level CMake project for the ctranslate2 bindings.
cmake_minimum_required(VERSION 3.22)
project(ctranslate2_bindings)

# CTranslate2 is built from the CTranslate2/ subdirectory.
add_subdirectory(CTranslate2)

# `dummy` links privately against ctranslate2; it is the target whose link
# libraries are exported below.
add_library(dummy src/dummy.cc)
target_link_libraries(dummy PRIVATE ctranslate2)

# Export the libraries linked by `dummy` (see cmake/export_libs.cmake).
include(cmake/export_libs.cmake)
export_all_target_libs(dummy)
|
@ -1 +0,0 @@
|
||||
Subproject commit 8bcbeb6ff95b6906c9d5f7740fa9491431fa3e30
|
@ -1,25 +0,0 @@
|
||||
[package]
|
||||
name = "ctranslate2-bindings"
|
||||
version = "0.5.0-dev"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
cxx = "1.0"
|
||||
derive_builder = { workspace = true }
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt"] }
|
||||
tokio-util = { workspace = true }
|
||||
tabby-inference = { path = "../tabby-inference" }
|
||||
async-trait = { workspace = true }
|
||||
futures.workspace = true
|
||||
async-stream.workspace = true
|
||||
|
||||
[build-dependencies]
|
||||
cxx-build = "1.0"
|
||||
cmake = { version = "0.1", optional = true }
|
||||
rust-cxx-cmake-bridge = { path = "../rust-cxx-cmake-bridge", optional = true }
|
||||
|
||||
[features]
|
||||
default = ["dep:cmake", "dep:rust-cxx-cmake-bridge"]
|
||||
link_shared = []
|
||||
link_static_cuda = []
|
@ -1,74 +0,0 @@
|
||||
use std::{env, path::PathBuf};
|
||||
|
||||
use cmake::Config;
|
||||
use rust_cxx_cmake_bridge::read_cmake_generated;
|
||||
|
||||
fn main() {
|
||||
// Tell cargo to invalidate the built crate whenever the wrapper changes
|
||||
println!("cargo:rerun-if-changed=include/ctranslate2.h");
|
||||
println!("cargo:rerun-if-changed=src/ctranslate2.cc");
|
||||
println!("cargo:rerun-if-changed=src/lib.rs");
|
||||
|
||||
let mut lib = cxx_build::bridge("src/lib.rs");
|
||||
lib.file("src/ctranslate2.cc")
|
||||
.flag_if_supported("-std=c++17");
|
||||
|
||||
if cfg!(feature = "link_shared") {
|
||||
let dir = env::var("CTRANSLATE2_ROOT").unwrap();
|
||||
println!("cargo:rustc-link-search=native={}/lib", dir);
|
||||
println!("cargo:rustc-link-lib=ctranslate2");
|
||||
lib.flag_if_supported(&format!("-I{}/include", dir));
|
||||
} else {
|
||||
let dst = link_static();
|
||||
lib.flag_if_supported(&format!("-I{}", dst.join("include").display()));
|
||||
}
|
||||
|
||||
lib.compile("cxxbridge");
|
||||
}
|
||||
|
||||
fn link_static() -> PathBuf {
|
||||
let mut config = Config::new(".");
|
||||
config
|
||||
.define("CMAKE_BUILD_TYPE", "Release")
|
||||
.define("BUILD_CLI", "OFF")
|
||||
.define("CMAKE_INSTALL_RPATH_USE_LINK_PATH", "ON")
|
||||
.define("BUILD_SHARED_LIBS", "OFF");
|
||||
|
||||
if cfg!(target_os = "linux") {
|
||||
config
|
||||
.define("WITH_MKL", "OFF")
|
||||
.define("OPENMP_RUNTIME", "NONE");
|
||||
|
||||
if cfg!(target_feature = "sse4.1") {
|
||||
config.cxxflag("-msse4.1");
|
||||
}
|
||||
|
||||
if cfg!(feature = "link_static_cuda") {
|
||||
config.define("WITH_CUDA", "ON").define("WITH_CUDNN", "ON");
|
||||
|
||||
if cfg!(target_arch = "aarch64") {
|
||||
config.cxxflag("-mcpu=native");
|
||||
}
|
||||
} else {
|
||||
config.define("WITH_OPENBLAS", "ON");
|
||||
}
|
||||
} else if cfg!(target_os = "macos") {
|
||||
config
|
||||
.define("CMAKE_OSX_ARCHITECTURES", "arm64")
|
||||
.define("WITH_ACCELERATE", "ON")
|
||||
.define("WITH_MKL", "OFF")
|
||||
.define("OPENMP_RUNTIME", "NONE")
|
||||
.define("WITH_RUY", "ON");
|
||||
} else {
|
||||
panic!("Invalid target")
|
||||
};
|
||||
|
||||
let dst = config.build();
|
||||
|
||||
// Read static lib from generated deps.
|
||||
let cmake_generated_libs_str =
|
||||
std::fs::read_to_string(format!("/{}/build/cmake_generated_libs", dst.display())).unwrap();
|
||||
read_cmake_generated(&cmake_generated_libs_str);
|
||||
|
||||
dst
|
||||
}
|
@ -1,25 +0,0 @@
|
||||
#! /bin/bash
#
# Recreates ./build, configures the parent directory's CMake project there with
# platform-specific options, then runs the command passed as arguments from
# inside ./build.
#
# Exits with status 1 on platforms other than Linux and macOS.

set -e
set -x

UNAME="$(uname -s)"
case "${UNAME}" in
    Linux*)  MACHINE=linux;;
    Darwin*) MACHINE=macos;;
    *)       exit 1;;
esac

rm -rf build
mkdir build && cd build

if [[ "$MACHINE" == "macos" ]]; then
    CMAKE_EXTRA_OPTIONS='-DCMAKE_OSX_ARCHITECTURES=arm64 -DWITH_ACCELERATE=ON -DWITH_MKL=OFF -DOPENMP_RUNTIME=NONE -DWITH_RUY=ON'
elif [[ "$MACHINE" == "linux" ]]; then
    # NOTE(review): -DCXXFLAGS only defines a CMake cache entry named CXXFLAGS;
    # compiler flags are normally passed via CMAKE_CXX_FLAGS or the CXXFLAGS
    # environment variable. Confirm -msse4.1 actually reaches the compiler.
    CMAKE_EXTRA_OPTIONS='-DWITH_CUDA=ON -DWITH_CUDNN=ON -DWITH_MKL=ON -DWITH_DNNL=ON -DOPENMP_RUNTIME=COMP -DCUDA_NVCC_FLAGS=-Xfatbin=-compress-all -DCUDA_ARCH_LIST=Common -DCXXFLAGS=-msse4.1'
fi

# BUILD_SHARED_LIBS was previously misspelled (BULID_...), so the option was
# silently ignored. $CMAKE_EXTRA_OPTIONS is intentionally unquoted so that it
# splits into separate arguments.
cmake -DBUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release -DBUILD_CLI=OFF -DCMAKE_INSTALL_RPATH_USE_LINK_PATH=ON $CMAKE_EXTRA_OPTIONS ..

# Run the caller-supplied command (if any) inside the build directory.
"$@"
|
@ -1,98 +0,0 @@
|
||||
################################################################################
|
||||
|
||||
# WARNING: to list the system libraries(ie IMPORTED) you MUST set:
|
||||
# set_target_properties(your_lib PROPERTIES IMPORTED_GLOBAL TRUE)
|
||||
# just after the find_package call
|
||||
# cf https://gitlab.kitware.com/cmake/cmake/-/issues/17256
|
||||
#
|
||||
# https://stackoverflow.com/questions/32756195/recursive-list-of-link-libraries-in-cmake
|
||||
# https://stackoverflow.com/questions/32197663/how-can-i-remove-the-the-location-property-may-not-be-read-from-target-error-i
|
||||
# Recursively collects what TARGET links against.
#
#   OUTPUT_LIST  name of the variable, set in the caller's scope, that receives the result
#   TARGET       target whose INTERFACE_LINK_LIBRARIES and LINK_LIBRARIES are walked
#
# Entries that are CMake targets are expanded depth-first, so a target's own
# dependencies are listed before the target itself. Entries that are existing
# paths are passed through unchanged; anything else is dropped.
# VISITED_TARGETS is written back to the caller's scope so that a target which
# was already expanded is not expanded again.
function(_get_link_libraries OUTPUT_LIST TARGET)
  list(APPEND VISITED_TARGETS ${TARGET})

  set(LIB_FILES "")

  # DO NOT switch on IMPORTED or not
  # An INTERFACE library CAN have LINK_LIBRARIES!
  # get_target_property(IMPORTED ${TARGET} IMPORTED)
  set(LIBS "")
  get_target_property(LIBS_1 ${TARGET} INTERFACE_LINK_LIBRARIES)
  get_target_property(LIBS_2 ${TARGET} LINK_LIBRARIES)
  list(APPEND LIBS ${LIBS_1} ${LIBS_2})

  foreach(LIB ${LIBS})
    if(TARGET ${LIB})
      list(FIND VISITED_TARGETS ${LIB} VISITED)
      if(VISITED EQUAL -1)
        # The target name itself is recorded instead of reading its LOCATION
        # property; the file path is resolved later via a generator expression.
        _get_link_libraries(LINK_LIB_FILES ${LIB})
        list(APPEND LIB_FILES ${LINK_LIB_FILES} ${LIB})
      endif()
    elseif(EXISTS ${LIB})
      list(APPEND LIB_FILES ${LIB})
    endif()
  endforeach()

  set(VISITED_TARGETS ${VISITED_TARGETS} PARENT_SCOPE)
  set(${OUTPUT_LIST} ${LIB_FILES} PARENT_SCOPE)
endfunction()
|
||||
|
||||
################################################################################
|
||||
|
||||
# Attaches a POST_BUILD step to TARGET that writes the library files TARGET
# links against (as collected by _get_link_libraries) to
# ${CMAKE_CURRENT_BINARY_DIR}/cmake_generated_libs.
function(export_all_target_libs TARGET)
  # get_target_property(... LINK_LIBRARIES) is NOT transitive, hence the
  # recursive helper. Its result is a list of target names and paths;
  # generator expressions such as $<TARGET_FILE:...> are only evaluated later.
  # cf https://stackoverflow.com/questions/59226127/cmake-generator-expression-how-to-get-target-file-property-on-list-of-targets
  set(ALL_LINK_LIBRARIES "")
  _get_link_libraries(ALL_LINK_LIBRARIES ${TARGET})
  message(STATUS "ALL_LINK_LIBRARIES : ${ALL_LINK_LIBRARIES}")

  set(ALL_LIBS "")
  set(ALL_EXTERNAL_LIBS "")

  # TODO move that back into get_link_libraries
  # Two passes are required:
  #   1. collect all the LINK_LIBRARIES recursively (done above)
  #   2. map each entry to its TARGET_FILE, skipping INTERFACE_LIBRARY targets
  # An INTERFACE_LIBRARY can carry link libraries, but $<TARGET_FILE:...> on it
  # fails with: Target "..." is not an executable or library.
  foreach(LIB ${ALL_LINK_LIBRARIES})
    if(TARGET ${LIB})
      get_target_property(LIB_TYPE ${LIB} TYPE)
      message(STATUS "LIB_TYPE : ${LIB} = ${LIB_TYPE}")

      if(NOT ${LIB_TYPE} STREQUAL "INTERFACE_LIBRARY")
        list(APPEND ALL_LIBS $<TARGET_FILE:${LIB}>)
      endif()
    elseif(EXISTS ${LIB})
      message(STATUS "LIB_TYPE : ${LIB} = EXTERNAL")
      list(APPEND ALL_LIBS ${LIB})
    endif()
  endforeach()

  message(STATUS "ALL_LIBS : ${ALL_LIBS}")

  # The list is written to a file rather than only echoed to stdout: when the
  # build has no work to do nothing is echoed, which is harder to consume from
  # build.rs.
  add_custom_command(
    TARGET ${TARGET}
    POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E echo ${ALL_LIBS} > ${CMAKE_CURRENT_BINARY_DIR}/cmake_generated_libs
    # OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/cmake_generated_libs
    VERBATIM
  )
endfunction()
|
@ -1,30 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "rust/cxx.h"
|
||||
#include <memory>
|
||||
|
||||
namespace tabby {
|
||||
|
||||
struct InferenceContext;
|
||||
|
||||
typedef rust::Fn<bool(InferenceContext&, size_t, uint32_t, rust::String)> InferenceCallback;
|
||||
|
||||
class TextInferenceEngine {
|
||||
public:
|
||||
virtual ~TextInferenceEngine();
|
||||
virtual rust::Vec<uint32_t> inference(
|
||||
rust::Box<InferenceContext> context,
|
||||
InferenceCallback callback,
|
||||
rust::Slice<const rust::String> tokens,
|
||||
size_t max_decoding_length,
|
||||
float sampling_temperature
|
||||
) const = 0;
|
||||
};
|
||||
|
||||
std::shared_ptr<TextInferenceEngine> create_engine(
|
||||
rust::Str model_path,
|
||||
rust::Str model_type,
|
||||
rust::Str device,
|
||||
rust::Slice<const int32_t> device_indices
|
||||
);
|
||||
} // namespace
|
@ -1,135 +0,0 @@
|
||||
#include "ctranslate2-bindings/include/ctranslate2.h"
|
||||
|
||||
#include "ctranslate2/translator.h"
|
||||
#include "ctranslate2/generator.h"
|
||||
|
||||
namespace tabby {
|
||||
TextInferenceEngine::~TextInferenceEngine() {}
|
||||
|
||||
template <class Model, class Child>
|
||||
class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||
protected:
|
||||
struct Options {
|
||||
size_t max_decoding_length;
|
||||
float sampling_temperature;
|
||||
};
|
||||
|
||||
public:
|
||||
rust::Vec<uint32_t> inference(
|
||||
rust::Box<InferenceContext> context,
|
||||
InferenceCallback callback,
|
||||
rust::Slice<const rust::String> tokens,
|
||||
size_t max_decoding_length,
|
||||
float sampling_temperature
|
||||
) const {
|
||||
// Inference.
|
||||
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
||||
return process(
|
||||
std::move(context),
|
||||
std::move(callback),
|
||||
input_tokens,
|
||||
Options{max_decoding_length, sampling_temperature}
|
||||
);
|
||||
}
|
||||
|
||||
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
|
||||
auto impl = std::make_unique<Child>();
|
||||
impl->model_ = std::make_unique<Model>(loader);
|
||||
return impl;
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual rust::Vec<uint32_t> process(
|
||||
rust::Box<InferenceContext> context,
|
||||
InferenceCallback callback,
|
||||
const std::vector<std::string>& tokens,
|
||||
const Options& options) const = 0;
|
||||
std::unique_ptr<Model> model_;
|
||||
};
|
||||
|
||||
// Engine backed by ctranslate2::Translator.
class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
 protected:
  virtual rust::Vec<uint32_t> process(
      rust::Box<InferenceContext> context,
      InferenceCallback callback,
      const std::vector<std::string>& tokens,
      const Options& options) const override {
    rust::Vec<uint32_t> output_ids;

    ctranslate2::TranslationOptions translation_options;
    translation_options.max_decoding_length = options.max_decoding_length;
    translation_options.sampling_temperature = options.sampling_temperature;
    translation_options.beam_size = 1;
    translation_options.callback = [&](ctranslate2::GenerationStepResult step) {
      const bool stop = callback(*context, step.step, step.token_id, step.token);
      // A token is recorded unless the callback asked to stop; the final
      // token is recorded even when the callback asked to stop.
      if (!stop || step.is_last) {
        output_ids.push_back(step.token_id);
      }
      return stop;
    };

    // The returned result is not used; token ids are gathered by the callback.
    ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, translation_options)[0];
    return output_ids;
  }
};
|
||||
|
||||
// Engine backed by ctranslate2::Generator.
class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
 protected:
  virtual rust::Vec<uint32_t> process(
      rust::Box<InferenceContext> context,
      InferenceCallback callback,
      const std::vector<std::string>& tokens,
      const Options& options) const override {
    rust::Vec<uint32_t> output_ids;

    ctranslate2::GenerationOptions generation_options;
    generation_options.include_prompt_in_result = false;
    generation_options.max_length = options.max_decoding_length;
    generation_options.sampling_temperature = options.sampling_temperature;
    generation_options.beam_size = 1;
    generation_options.callback = [&](ctranslate2::GenerationStepResult step) {
      const bool stop = callback(*context, step.step, step.token_id, step.token);
      // A token is recorded unless the callback asked to stop; the final
      // token is recorded even when the callback asked to stop.
      if (!stop || step.is_last) {
        output_ids.push_back(step.token_id);
      }
      return stop;
    };

    // .get() blocks until the asynchronous generation has finished; the result
    // itself is not used because token ids are gathered by the callback.
    ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, generation_options)[0].get();
    return output_ids;
  }
};
|
||||
|
||||
// Loads the model at `model_path` on `device` and wraps it in the engine that
// matches `model_type`.
//
// Recognised model types are "AutoModelForCausalLM" (DecoderImpl) and
// "AutoModelForSeq2SeqLM" (EncoderDecoderImpl); any other value yields nullptr,
// so callers must check the returned pointer.
std::shared_ptr<TextInferenceEngine> create_engine(
    rust::Str model_path,
    rust::Str model_type,
    rust::Str device,
    rust::Slice<const int32_t> device_indices
) {
  std::string model_type_str(model_type);
  std::string model_path_str(model_path);
  ctranslate2::models::ModelLoader loader(model_path_str);
  loader.device = ctranslate2::str_to_device(std::string(device));
  loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
  loader.compute_type = ctranslate2::ComputeType::AUTO;

  // hardware_concurrency() is allowed to return 0 when the value cannot be
  // determined. Treat that as a single CPU so the CUDA branch below never
  // requests zero replicas.
  const size_t num_cpus = std::max<size_t>(std::thread::hardware_concurrency(), 1);
  if (loader.device == ctranslate2::Device::CUDA) {
    // When device is cuda, set parallelism to be number of thread, capped to 4 to avoid VRAM oom.
    loader.num_replicas_per_device = std::min<int32_t>(num_cpus, 4);
  } else if (loader.device == ctranslate2::Device::CPU) {
    // When device is cpu, adjust the number based on threads per replica.
    // https://github.com/OpenNMT/CTranslate2/blob/master/src/utils.cc#L77
    loader.num_replicas_per_device = std::max<int32_t>(num_cpus / 4, 1);
  }

  if (model_type_str == "AutoModelForCausalLM") {
    return DecoderImpl::create(loader);
  } else if (model_type_str == "AutoModelForSeq2SeqLM") {
    return EncoderDecoderImpl::create(loader);
  } else {
    return nullptr;
  }
}
|
||||
} // namespace tabby
|
@ -1,178 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_stream::stream;
|
||||
use async_trait::async_trait;
|
||||
use derive_builder::Builder;
|
||||
use futures::stream::BoxStream;
|
||||
use tabby_inference::{
|
||||
decoding::{DecodingFactory, IncrementalDecoding},
|
||||
helpers, TextGeneration, TextGenerationOptions,
|
||||
};
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokio::sync::mpsc::{channel, Sender};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
#[cxx::bridge(namespace = "tabby")]
|
||||
mod ffi {
|
||||
extern "Rust" {
|
||||
type InferenceContext;
|
||||
}
|
||||
|
||||
unsafe extern "C++" {
|
||||
include!("ctranslate2-bindings/include/ctranslate2.h");
|
||||
|
||||
type TextInferenceEngine;
|
||||
|
||||
fn create_engine(
|
||||
model_path: &str,
|
||||
model_type: &str,
|
||||
device: &str,
|
||||
device_indices: &[i32],
|
||||
) -> SharedPtr<TextInferenceEngine>;
|
||||
|
||||
fn inference(
|
||||
&self,
|
||||
context: Box<InferenceContext>,
|
||||
callback: fn(
|
||||
&mut InferenceContext,
|
||||
// step
|
||||
usize,
|
||||
// token_id
|
||||
u32,
|
||||
// token
|
||||
String,
|
||||
) -> bool,
|
||||
tokens: &[String],
|
||||
max_decoding_length: usize,
|
||||
sampling_temperature: f32,
|
||||
) -> Vec<u32>;
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for ffi::TextInferenceEngine {}
|
||||
unsafe impl Sync for ffi::TextInferenceEngine {}
|
||||
|
||||
#[derive(Builder, Debug)]
|
||||
pub struct CTranslate2EngineOptions {
|
||||
model_path: String,
|
||||
|
||||
model_type: String,
|
||||
|
||||
tokenizer_path: String,
|
||||
|
||||
device: String,
|
||||
|
||||
device_indices: Vec<i32>,
|
||||
}
|
||||
|
||||
pub struct InferenceContext {
|
||||
sender: Sender<String>,
|
||||
decoding: IncrementalDecoding,
|
||||
cancel: CancellationToken,
|
||||
}
|
||||
|
||||
impl InferenceContext {
|
||||
fn new(
|
||||
sender: Sender<String>,
|
||||
decoding: IncrementalDecoding,
|
||||
cancel: CancellationToken,
|
||||
) -> Self {
|
||||
InferenceContext {
|
||||
sender,
|
||||
decoding,
|
||||
cancel,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CTranslate2Engine {
|
||||
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||
decoding_factory: DecodingFactory,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
}
|
||||
|
||||
impl CTranslate2Engine {
|
||||
pub fn create(options: CTranslate2EngineOptions) -> Self where {
|
||||
let engine = ffi::create_engine(
|
||||
&options.model_path,
|
||||
&options.model_type,
|
||||
&options.device,
|
||||
&options.device_indices,
|
||||
);
|
||||
|
||||
return Self {
|
||||
engine,
|
||||
decoding_factory: DecodingFactory::default(),
|
||||
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TextGeneration for CTranslate2Engine {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||
let s = self.generate_stream(prompt, options).await;
|
||||
helpers::stream_to_string(s).await
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
&self,
|
||||
prompt: &str,
|
||||
options: TextGenerationOptions,
|
||||
) -> BoxStream<String> {
|
||||
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||
let decoding = self.decoding_factory.create_incremental_decoding(
|
||||
self.tokenizer.clone(),
|
||||
truncate_tokens(encoding.get_ids(), options.max_input_length),
|
||||
options.language,
|
||||
);
|
||||
|
||||
let cancel = CancellationToken::new();
|
||||
let engine = self.engine.clone();
|
||||
let (sender, mut receiver) = channel::<String>(8);
|
||||
let context = InferenceContext::new(sender, decoding, cancel.clone());
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let context = Box::new(context);
|
||||
engine.inference(
|
||||
context,
|
||||
inference_callback,
|
||||
truncate_tokens(encoding.get_tokens(), options.max_input_length),
|
||||
options.max_decoding_length,
|
||||
options.sampling_temperature,
|
||||
);
|
||||
});
|
||||
|
||||
let s = stream! {
|
||||
let _guard = cancel.drop_guard();
|
||||
while let Some(text) = receiver.recv().await {
|
||||
yield text;
|
||||
}
|
||||
};
|
||||
Box::pin(s)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the last `max_length` elements of `tokens`, or all of `tokens`
/// when it holds `max_length` elements or fewer.
fn truncate_tokens<T>(tokens: &[T], max_length: usize) -> &[T] {
    // saturating_sub yields 0 when the slice already fits, so no branch is needed.
    let start = tokens.len().saturating_sub(max_length);
    &tokens[start..]
}
|
||||
|
||||
fn inference_callback(
|
||||
context: &mut InferenceContext,
|
||||
_step: usize,
|
||||
token_id: u32,
|
||||
_token: String,
|
||||
) -> bool {
|
||||
if context.cancel.is_cancelled() {
|
||||
true
|
||||
} else if let Some(new_text) = context.decoding.next_token(token_id) {
|
||||
let _ = context.sender.blocking_send(new_text);
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
@ -16,7 +16,6 @@ async-trait = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt"] }
|
||||
tabby-inference = { path = "../tabby-inference" }
|
||||
derive_builder = { workspace = true }
|
||||
tokenizers = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
futures.workspace = true
|
||||
async-stream.workspace = true
|
||||
|
@ -4,16 +4,15 @@
|
||||
#include <memory>
|
||||
|
||||
namespace llama {
|
||||
struct StepOutput;
|
||||
|
||||
class TextInferenceEngine {
|
||||
public:
|
||||
virtual ~TextInferenceEngine();
|
||||
|
||||
virtual void add_request(uint32_t request_id, rust::Slice<const uint32_t> input_token_ids) = 0;
|
||||
virtual void add_request(uint32_t request_id, rust::Str text, size_t max_input_length) = 0;
|
||||
virtual void stop_request(uint32_t request_id) = 0;
|
||||
virtual rust::Vec<uint32_t> step() = 0;
|
||||
|
||||
virtual uint32_t eos_token_id() const = 0;
|
||||
virtual rust::Vec<StepOutput> step() = 0;
|
||||
};
|
||||
|
||||
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path);
|
||||
|
@ -8,6 +8,8 @@
|
||||
#include <ggml.h>
|
||||
#include <llama.h>
|
||||
|
||||
#include "llama-cpp-bindings/src/lib.rs.h"
|
||||
|
||||
namespace llama {
|
||||
TextInferenceEngine::~TextInferenceEngine() {}
|
||||
|
||||
@ -27,20 +29,56 @@ constexpr size_t N_BATCH = 512; // # per batch inference.
|
||||
constexpr size_t N_CTX = 4096; // # max kv history.
|
||||
|
||||
struct Request {
|
||||
Request(size_t request_id, rust::Slice<const uint32_t> input_token_ids) :
|
||||
Request(size_t request_id, std::vector<llama_token> input_token_ids) :
|
||||
id(request_id),
|
||||
tokens(input_token_ids.begin(), input_token_ids.end()) {
|
||||
}
|
||||
|
||||
size_t id = -1;
|
||||
uint32_t id = -1;
|
||||
llama_seq_id seq_id = -1;
|
||||
|
||||
std::vector<llama_token> tokens;
|
||||
size_t i_batch = -1;
|
||||
size_t n_past = 0;
|
||||
|
||||
int32_t multibyte_pending = 0;
|
||||
std::string generated_text;
|
||||
};
|
||||
|
||||
|
||||
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
|
||||
std::vector<char> result(8, 0);
|
||||
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
|
||||
if (n_tokens < 0) {
|
||||
result.resize(-n_tokens);
|
||||
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
|
||||
GGML_ASSERT(check == -n_tokens);
|
||||
} else {
|
||||
result.resize(n_tokens);
|
||||
}
|
||||
|
||||
return std::string(result.data(), result.size());
|
||||
}
|
||||
|
||||
std::vector<llama_token> llama_tokenize(
|
||||
const struct llama_model * model,
|
||||
const rust::Str & text,
|
||||
bool add_bos,
|
||||
bool special) {
|
||||
// upper limit for the number of tokens
|
||||
int n_tokens = text.length() + add_bos;
|
||||
std::vector<llama_token> result(n_tokens);
|
||||
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
|
||||
if (n_tokens < 0) {
|
||||
result.resize(-n_tokens);
|
||||
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
|
||||
GGML_ASSERT(check == -n_tokens);
|
||||
} else {
|
||||
result.resize(n_tokens);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
using owned = std::unique_ptr<T, std::function<void(T*)>>;
|
||||
|
||||
@ -56,15 +94,20 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||
llama_batch_free(batch_);
|
||||
}
|
||||
|
||||
void add_request(uint32_t request_id, rust::Slice<const uint32_t> input_token_ids) override {
|
||||
pending_requests_.push_back(Request(request_id, input_token_ids));
|
||||
virtual void add_request(uint32_t request_id, rust::Str text, size_t max_input_length) override {
|
||||
auto tokens = llama_tokenize(llama_get_model(ctx_.get()), text, false, true);
|
||||
if (tokens.size() > max_input_length) {
|
||||
int start = tokens.size() - max_input_length;
|
||||
tokens = std::vector<llama_token>(tokens.begin() + start, tokens.end());
|
||||
}
|
||||
pending_requests_.push_back(Request(request_id, tokens));
|
||||
}
|
||||
|
||||
void stop_request(uint32_t request_id) override {
|
||||
stopped_requests_.insert(request_id);
|
||||
}
|
||||
|
||||
rust::Vec<uint32_t> step() override {
|
||||
rust::Vec<StepOutput> step() override {
|
||||
auto* ctx = ctx_.get();
|
||||
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
|
||||
@ -123,28 +166,29 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||
request.i_batch = batch_.n_tokens - 1;
|
||||
}
|
||||
|
||||
rust::Vec<uint32_t> result;
|
||||
result.reserve(requests_.size() * 2);
|
||||
rust::Vec<StepOutput> result;
|
||||
result.reserve(requests_.size());
|
||||
|
||||
// Decode tokens in chunks
|
||||
for (size_t i = 0; i < static_cast<size_t>(batch_.n_tokens); i += N_BATCH) {
|
||||
const int32_t n_tokens = std::min(N_BATCH, batch_.n_tokens - i);
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch_.token + i,
|
||||
nullptr,
|
||||
batch_.pos + i,
|
||||
batch_.n_seq_id + i,
|
||||
batch_.seq_id + i,
|
||||
batch_.logits + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch_.token + i,
|
||||
nullptr,
|
||||
batch_.pos + i,
|
||||
batch_.n_seq_id + i,
|
||||
batch_.seq_id + i,
|
||||
batch_.logits + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
|
||||
const int ret = llama_decode(ctx, batch_view);
|
||||
const int ret = llama_decode(ctx, batch_view);
|
||||
if (ret != 0) {
|
||||
throw std::runtime_error("Failed to eval");
|
||||
}
|
||||
|
||||
const auto eos_id = llama_token_eos(llama_get_model(ctx));
|
||||
for (auto& request : requests_) {
|
||||
if ((request.i_batch < i) || (request.i_batch >= (i + n_tokens))) {
|
||||
continue;
|
||||
@ -159,18 +203,44 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||
request.tokens.clear();
|
||||
request.tokens.push_back(next_token);
|
||||
|
||||
result.push_back(request.id);
|
||||
result.push_back(next_token);
|
||||
const auto token_str = llama_token_to_piece(ctx, next_token);
|
||||
request.generated_text += token_str;
|
||||
|
||||
// FIXME: Hack for codellama to simplify tabby's implementation.
|
||||
const bool is_eos = next_token == eos_id || token_str == " <EOT>";
|
||||
|
||||
if (request.multibyte_pending > 0) {
|
||||
request.multibyte_pending -= token_str.size();
|
||||
} else if (token_str.size() == 1) {
|
||||
const char c = token_str[0];
|
||||
// 2-byte characters: 110xxxxx 10xxxxxx
|
||||
if ((c & 0xE0) == 0xC0) {
|
||||
request.multibyte_pending = 1;
|
||||
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
|
||||
}
|
||||
else if ((c & 0xF0) == 0xE0) {
|
||||
request.multibyte_pending = 2;
|
||||
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
|
||||
} else if ((c & 0xF8) == 0xF0) {
|
||||
request.multibyte_pending = 3;
|
||||
}
|
||||
else {
|
||||
request.multibyte_pending = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (request.multibyte_pending == 0) {
|
||||
rust::String generated_text = is_eos ? "" : request.generated_text;
|
||||
result.push_back({request.id, generated_text});
|
||||
|
||||
request.generated_text.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
uint32_t eos_token_id() const override {
|
||||
return llama_token_eos(llama_get_model(ctx_.get()));
|
||||
}
|
||||
|
||||
private:
|
||||
owned<llama_model> model_;
|
||||
owned<llama_context> ctx_;
|
||||
|
@ -7,10 +7,9 @@ use derive_builder::Builder;
|
||||
use ffi::create_engine;
|
||||
use futures::{lock::Mutex, stream::BoxStream};
|
||||
use tabby_inference::{
|
||||
decoding::{DecodingFactory, IncrementalDecoding},
|
||||
decoding::{StopCondition, StopConditionFactory},
|
||||
helpers, TextGeneration, TextGenerationOptions,
|
||||
};
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokio::{
|
||||
sync::mpsc::{channel, Sender},
|
||||
task::yield_now,
|
||||
@ -18,6 +17,11 @@ use tokio::{
|
||||
|
||||
#[cxx::bridge(namespace = "llama")]
|
||||
mod ffi {
|
||||
struct StepOutput {
|
||||
request_id: u32,
|
||||
text: String,
|
||||
}
|
||||
|
||||
unsafe extern "C++" {
|
||||
include!("llama-cpp-bindings/include/engine.h");
|
||||
|
||||
@ -28,12 +32,11 @@ mod ffi {
|
||||
fn add_request(
|
||||
self: Pin<&mut TextInferenceEngine>,
|
||||
request_id: u32,
|
||||
input_token_ids: &[u32],
|
||||
prompt: &str,
|
||||
max_input_length: usize,
|
||||
);
|
||||
fn stop_request(self: Pin<&mut TextInferenceEngine>, request_id: u32);
|
||||
fn step(self: Pin<&mut TextInferenceEngine>) -> Result<Vec<u32>>;
|
||||
|
||||
fn eos_token_id(&self) -> u32;
|
||||
fn step(self: Pin<&mut TextInferenceEngine>) -> Result<Vec<StepOutput>>;
|
||||
}
|
||||
}
|
||||
|
||||
@ -42,26 +45,22 @@ unsafe impl Sync for ffi::TextInferenceEngine {}
|
||||
|
||||
struct InferenceRequest {
|
||||
tx: Sender<String>,
|
||||
decoding: IncrementalDecoding,
|
||||
stop_condition: StopCondition,
|
||||
}
|
||||
|
||||
struct AsyncTextInferenceEngine {
|
||||
engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
decoding_factory: DecodingFactory,
|
||||
stop_condition_factory: StopConditionFactory,
|
||||
requests: Mutex<HashMap<u32, InferenceRequest>>,
|
||||
|
||||
next_request_id: Mutex<u32>,
|
||||
eos_token_id: u32,
|
||||
}
|
||||
|
||||
impl AsyncTextInferenceEngine {
|
||||
fn create(engine: UniquePtr<ffi::TextInferenceEngine>, tokenizer: Tokenizer) -> Self {
|
||||
fn create(engine: UniquePtr<ffi::TextInferenceEngine>) -> Self {
|
||||
Self {
|
||||
eos_token_id: engine.eos_token_id(),
|
||||
engine: Mutex::new(engine),
|
||||
tokenizer: Arc::new(tokenizer),
|
||||
decoding_factory: DecodingFactory::default(),
|
||||
stop_condition_factory: StopConditionFactory::default(),
|
||||
requests: Mutex::new(HashMap::new()),
|
||||
next_request_id: Mutex::new(0),
|
||||
}
|
||||
@ -79,18 +78,15 @@ impl AsyncTextInferenceEngine {
|
||||
panic!("Failed to evaluation");
|
||||
};
|
||||
|
||||
for i in (0..result.len()).step_by(2) {
|
||||
let request_id = result[i];
|
||||
let token_id = result[i + 1];
|
||||
|
||||
let InferenceRequest { tx, decoding } = requests.get_mut(&request_id).unwrap();
|
||||
for ffi::StepOutput { request_id, text } in result {
|
||||
let mut stopped = false;
|
||||
let InferenceRequest { tx, stop_condition } = requests.get_mut(&request_id).unwrap();
|
||||
|
||||
if tx.is_closed() || token_id == self.eos_token_id {
|
||||
if tx.is_closed() || text.is_empty() {
|
||||
// Cancelled by client side or hit eos.
|
||||
stopped = true;
|
||||
} else if let Some(new_text) = decoding.next_token(token_id) {
|
||||
match tx.send(new_text).await {
|
||||
} else if !stop_condition.should_stop(&text) {
|
||||
match tx.send(text).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => stopped = true,
|
||||
}
|
||||
@ -111,25 +107,21 @@ impl AsyncTextInferenceEngine {
|
||||
prompt: &str,
|
||||
options: TextGenerationOptions,
|
||||
) -> BoxStream<String> {
|
||||
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
|
||||
let decoding = self.decoding_factory.create_incremental_decoding(
|
||||
self.tokenizer.clone(),
|
||||
input_token_ids,
|
||||
options.language,
|
||||
);
|
||||
let stop_condition = self.stop_condition_factory.create(prompt, options.language);
|
||||
|
||||
let (tx, mut rx) = channel::<String>(4);
|
||||
{
|
||||
let mut engine = self.engine.lock().await;
|
||||
let engine = engine.as_mut().unwrap();
|
||||
|
||||
let mut request_id = self.next_request_id.lock().await;
|
||||
self.requests
|
||||
.lock()
|
||||
.await
|
||||
.insert(*request_id, InferenceRequest { tx, decoding });
|
||||
engine.add_request(*request_id, input_token_ids);
|
||||
.insert(*request_id, InferenceRequest { tx, stop_condition });
|
||||
engine
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.add_request(*request_id, prompt, options.max_input_length);
|
||||
|
||||
// 2048 should be large enough to avoid collision.
|
||||
*request_id = (*request_id + 1) % 2048;
|
||||
@ -155,7 +147,6 @@ impl AsyncTextInferenceEngine {
|
||||
#[derive(Builder, Debug)]
|
||||
pub struct LlamaTextGenerationOptions {
|
||||
model_path: String,
|
||||
tokenizer_path: String,
|
||||
use_gpu: bool,
|
||||
}
|
||||
|
||||
@ -169,9 +160,8 @@ impl LlamaTextGeneration {
|
||||
if engine.is_null() {
|
||||
panic!("Unable to load model: {}", options.model_path);
|
||||
}
|
||||
let tokenizer = Tokenizer::from_file(&options.tokenizer_path).unwrap();
|
||||
let ret = LlamaTextGeneration {
|
||||
engine: Arc::new(AsyncTextInferenceEngine::create(engine, tokenizer)),
|
||||
engine: Arc::new(AsyncTextInferenceEngine::create(engine)),
|
||||
};
|
||||
ret.start_background_job();
|
||||
ret
|
||||
@ -203,12 +193,3 @@ impl TextGeneration for LlamaTextGeneration {
|
||||
self.engine.generate_stream(prompt, options).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the last `max_length` token ids of `tokens`, or all of `tokens`
/// when it holds `max_length` ids or fewer.
fn truncate_tokens(tokens: &[u32], max_length: usize) -> &[u32] {
    // saturating_sub yields 0 when the slice already fits, so no branch is needed.
    let start = tokens.len().saturating_sub(max_length);
    &tokens[start..]
}
|
||||
|
@ -78,14 +78,6 @@ impl ModelDir {
|
||||
self.path_string("tabby.json")
|
||||
}
|
||||
|
||||
pub fn tokenizer_file(&self) -> String {
|
||||
self.path_string("tokenizer.json")
|
||||
}
|
||||
|
||||
pub fn ctranslate2_dir(&self) -> String {
|
||||
self.path_string("ctranslate2")
|
||||
}
|
||||
|
||||
pub fn ggml_q8_0_file(&self) -> String {
|
||||
self.path_string("ggml/q8_0.gguf")
|
||||
}
|
||||
|
@ -29,27 +29,8 @@ impl Downloader {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn download_ctranslate2_files(&self) -> Result<()> {
|
||||
let files = vec![
|
||||
("tabby.json", true),
|
||||
("tokenizer.json", true),
|
||||
("ctranslate2/vocabulary.txt", false),
|
||||
("ctranslate2/shared_vocabulary.txt", false),
|
||||
("ctranslate2/vocabulary.json", false),
|
||||
("ctranslate2/shared_vocabulary.json", false),
|
||||
("ctranslate2/config.json", true),
|
||||
("ctranslate2/model.bin", true),
|
||||
];
|
||||
|
||||
self.download_files(&files).await
|
||||
}
|
||||
|
||||
pub async fn download_ggml_files(&self) -> Result<()> {
|
||||
let files = vec![
|
||||
("tabby.json", true),
|
||||
("tokenizer.json", true),
|
||||
("ggml/q8_0.v2.gguf", true),
|
||||
];
|
||||
let files = vec![("tabby.json", true), ("ggml/q8_0.v2.gguf", true)];
|
||||
self.download_files(&files).await
|
||||
}
|
||||
|
||||
|
@ -12,5 +12,4 @@ dashmap = "5.5.3"
|
||||
derive_builder = "0.12.0"
|
||||
futures = { workspace = true }
|
||||
regex.workspace = true
|
||||
tokenizers.workspace = true
|
||||
tabby-common = { path = "../tabby-common" }
|
||||
|
@ -1,11 +1,8 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use regex::Regex;
|
||||
use tabby_common::languages::Language;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
|
||||
pub struct DecodingFactory {
|
||||
pub struct StopConditionFactory {
|
||||
stop_regex_cache: DashMap<String, Regex>,
|
||||
}
|
||||
|
||||
@ -16,7 +13,7 @@ where
|
||||
s.into().chars().rev().collect()
|
||||
}
|
||||
|
||||
impl Default for DecodingFactory {
|
||||
impl Default for StopConditionFactory {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
stop_regex_cache: DashMap::new(),
|
||||
@ -24,14 +21,9 @@ impl Default for DecodingFactory {
|
||||
}
|
||||
}
|
||||
|
||||
impl DecodingFactory {
|
||||
pub fn create_incremental_decoding(
|
||||
&self,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
input_token_ids: &[u32],
|
||||
language: &'static Language,
|
||||
) -> IncrementalDecoding {
|
||||
IncrementalDecoding::new(tokenizer, self.get_re(language), input_token_ids)
|
||||
impl StopConditionFactory {
|
||||
pub fn create(&self, text: &str, language: &'static Language) -> StopCondition {
|
||||
StopCondition::new(self.get_re(language), text)
|
||||
}
|
||||
|
||||
fn get_re(&self, language: &'static Language) -> Option<Regex> {
|
||||
@ -62,68 +54,31 @@ fn create_stop_regex(stop_words: Vec<String>) -> Regex {
|
||||
Regex::new(&regex_string).expect("Failed to create regex")
|
||||
}
|
||||
|
||||
pub struct IncrementalDecoding {
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
pub struct StopCondition {
|
||||
stop_re: Option<Regex>,
|
||||
|
||||
token_ids: Vec<u32>,
|
||||
prefix_offset: usize,
|
||||
read_offset: usize,
|
||||
|
||||
reversed_text: String,
|
||||
}
|
||||
|
||||
impl IncrementalDecoding {
|
||||
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
|
||||
let text = tokenizer
|
||||
.decode(input_token_ids, /* skip_special_token = */ true)
|
||||
.expect("Cannot decode token from tokenizer.");
|
||||
impl StopCondition {
|
||||
pub fn new(stop_re: Option<Regex>, text: &str) -> Self {
|
||||
Self {
|
||||
tokenizer,
|
||||
stop_re,
|
||||
token_ids: input_token_ids.to_owned(),
|
||||
prefix_offset: 0,
|
||||
read_offset: input_token_ids.len(),
|
||||
reversed_text: reverse(text),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next_token(&mut self, token_id: u32) -> Option<String> {
|
||||
let skip_special_token = true;
|
||||
self.token_ids.push(token_id);
|
||||
|
||||
let prefix_text = self
|
||||
.tokenizer
|
||||
.decode(
|
||||
&self.token_ids[self.prefix_offset..self.read_offset],
|
||||
skip_special_token,
|
||||
)
|
||||
.expect("Cannot decode token from tokenizer.");
|
||||
|
||||
let new_text = self
|
||||
.tokenizer
|
||||
.decode(&self.token_ids[self.prefix_offset..], skip_special_token)
|
||||
.expect("Cannot decode token from tokenizer.");
|
||||
|
||||
let new_text = if new_text.len() > prefix_text.len() && !new_text.ends_with('<27>') {
|
||||
self.prefix_offset = self.read_offset;
|
||||
self.read_offset = self.token_ids.len();
|
||||
&new_text[prefix_text.len()..]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
pub fn should_stop(&mut self, new_text: &str) -> bool {
|
||||
if !new_text.is_empty() {
|
||||
self.reversed_text = reverse(new_text) + &self.reversed_text;
|
||||
|
||||
if let Some(re) = &self.stop_re {
|
||||
if re.is_match(&self.reversed_text) {
|
||||
return None;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(new_text.to_owned())
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -41,7 +41,6 @@ pub struct EngineInfo {
|
||||
fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box<dyn TextGeneration> {
|
||||
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
|
||||
.model_path(model_dir.ggml_q8_0_v2_file())
|
||||
.tokenizer_path(model_dir.tokenizer_file())
|
||||
.use_gpu(device.ggml_use_gpu())
|
||||
.build()
|
||||
.unwrap();
|
||||
|
Loading…
Reference in New Issue
Block a user