@built-in-ai/transformers-js
Why?
Why use @built-in-ai/transformers-js instead of just Transformers.js?
The Problem
Running Hugging Face models in the browser with just @huggingface/transformers requires significant boilerplate code. You need to handle:
- Web Worker setup and communication
- Progress tracking
- Streaming token generation
- Interrupt handling
- Error states
- State management
Before: Just Transformers.js
Here's what a typical Web Worker looks like with just @huggingface/transformers:
import {
AutoTokenizer,
AutoModelForCausalLM,
TextStreamer,
InterruptableStoppingCriteria,
} from "@huggingface/transformers"
async function check() {
try {
const adapter = await navigator.gpu.requestAdapter()
if (!adapter) {
throw new Error("WebGPU is not supported (no adapter found)")
}
} catch (e) {
self.postMessage({
status: "error",
data: e.toString(),
})
}
}
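// Lazily create and cache the tokenizer and model so they are only loaded once.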
class TextGenerationPipeline {
static model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
static async getInstance(progress_callback = null) {
this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
progress_callback,
})
this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
dtype: "q4f16",
device: "webgpu",
progress_callback,
})
return Promise.all([this.tokenizer, this.model])
}
}
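// Interruptible stopping criteria plus a KV cache that persists across generations.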
const stopping_criteria = new InterruptableStoppingCriteria()
let past_key_values_cache = null
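// Generate a response for the given chat messages, streaming tokens (and tokens/sec) back to the main thread.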
async function generate(messages) {
const [tokenizer, model] = await TextGenerationPipeline.getInstance()
const inputs = tokenizer.apply_chat_template(messages, {
add_generation_prompt: true,
return_dict: true,
})
let startTime
let numTokens = 0
let tps
const token_callback_function = () => {
startTime ??= performance.now()
if (numTokens++ > 0) {
tps = (numTokens / (performance.now() - startTime)) * 1000
}
}
const callback_function = (output) => {
self.postMessage({
status: "update",
output,
tps,
numTokens,
})
}
const streamer = new TextStreamer(tokenizer, {
skip_prompt: true,
skip_special_tokens: true,
callback_function,
token_callback_function,
})
self.postMessage({ status: "start" })
const { past_key_values, sequences } = await model.generate({
...inputs,
past_key_values: past_key_values_cache,
max_new_tokens: 1024,
streamer,
stopping_criteria,
return_dict_in_generate: true,
})
past_key_values_cache = past_key_values
const decoded = tokenizer.batch_decode(sequences, {
skip_special_tokens: true,
})
self.postMessage({
status: "complete",
output: decoded,
})
}
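// Load the tokenizer and model, run a one-token warm-up generation to compile shaders, then report ready.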
async function load() {
self.postMessage({
status: "loading",
data: "Loading model...",
})
const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
self.postMessage(x)
})
self.postMessage({
status: "loading",
data: "Compiling shaders and warming up model...",
})
const inputs = tokenizer("a")
await model.generate({ ...inputs, max_new_tokens: 1 })
self.postMessage({ status: "ready" })
}
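// Route messages from the main thread (check / load / generate / interrupt / reset) to the functions above.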
self.addEventListener("message", async (e) => {
const { type, data } = e.data
switch (type) {
case "check":
check()
break
case "load":
load()
break
case "generate":
stopping_criteria.reset()
generate(data)
break
case "interrupt":
stopping_criteria.interrupt()
break
case "reset":
past_key_values_cache = null
stopping_criteria.reset()
break
}
})

That's ~130 lines of complex code just for the worker, and you still need to write the component code that handles all the message passing, state management, and so on.
After: With @built-in-ai/transformers-js
import { TransformersJSWorkerHandler } from "@built-in-ai/transformers-js"
const handler = new TransformersJSWorkerHandler()
self.onmessage = (msg: MessageEvent) => {
handler.onmessage(msg)
}

Six lines of code.
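The same model can also be driven outside of React. Below is a hedged sketch, assuming the object returned by transformersJS implements the standard AI SDK language model interface and can be passed directly to streamText from the ai package; the prompt and DOM wiring are purely illustrative.

// main.ts: a minimal sketch, not taken from the package docs.
// Assumption: transformersJS() returns an AI SDK-compatible language model,
// so the core streamText() helper from the "ai" package can drive it.
import { streamText } from "ai"
import { transformersJS } from "@built-in-ai/transformers-js"

const model = transformersJS("HuggingFaceTB/SmolLM2-360M-Instruct", {
  device: "webgpu",
  dtype: "q4",
  worker: new Worker(new URL("./worker.ts", import.meta.url), {
    type: "module",
  }),
})

const { textStream } = streamText({
  model,
  prompt: "Explain WebGPU in one sentence.",
})

// Append tokens to the page as they arrive (top-level await requires a module script).
for await (const chunk of textStream) {
  document.querySelector("#output")!.textContent += chunk
}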
Component Code Comparison
Before: Using just Transformers.js
function App() {
const worker = useRef(null);
const textareaRef = useRef(null);
const chatContainerRef = useRef(null);
const [status, setStatus] = useState(null);
const [error, setError] = useState(null);
const [loadingMessage, setLoadingMessage] = useState("");
const [progressItems, setProgressItems] = useState([]);
const [isRunning, setIsRunning] = useState(false);
const [input, setInput] = useState("");
const [messages, setMessages] = useState([]);
const [tps, setTps] = useState(null);
const [numTokens, setNumTokens] = useState(null);
function onEnter(message) {
setMessages((prev) => [...prev, { role: "user", content: message }]);
setTps(null);
setIsRunning(true);
setInput("");
}
function onInterrupt() {
worker.current.postMessage({ type: "interrupt" });
}
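// Create the worker once and translate each of its status messages into React state.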
useEffect(() => {
if (!worker.current) {
worker.current = new Worker(new URL("./worker.js", import.meta.url), {
type: "module",
});
worker.current.postMessage({ type: "check" });
}
const onMessageReceived = (e) => {
switch (e.data.status) {
case "loading":
setStatus("loading");
setLoadingMessage(e.data.data);
break;
case "initiate":
setProgressItems((prev) => [...prev, e.data]);
break;
case "progress":
setProgressItems((prev) =>
prev.map((item) => {
if (item.file === e.data.file) {
return { ...item, ...e.data };
}
return item;
}),
);
break;
case "done":
setProgressItems((prev) =>
prev.filter((item) => item.file !== e.data.file),
);
break;
case "ready":
setStatus("ready");
break;
case "start":
setMessages((prev) => [
...prev,
{ role: "assistant", content: "" },
]);
break;
case "update":
const { output, tps, numTokens } = e.data;
setTps(tps);
setNumTokens(numTokens);
setMessages((prev) => {
const cloned = [...prev];
const last = cloned.at(-1);
cloned[cloned.length - 1] = {
...last,
content: last.content + output,
};
return cloned;
});
break;
case "complete":
setIsRunning(false);
break;
case "error":
setError(e.data.data);
break;
}
};
// Surface worker-level errors through the existing error state.
const onErrorReceived = (e) => {
  setError(e.message);
};
worker.current.addEventListener("message", onMessageReceived);
worker.current.addEventListener("error", onErrorReceived);
return () => {
worker.current.removeEventListener("message", onMessageReceived);
worker.current.removeEventListener("error", onErrorReceived);
};
}, []);
// ... rest of the component
}

That's over 100 lines of state management and event handling, and the actual UI still isn't included.
After: With @built-in-ai/transformers-js
import { useChat } from "@ai-sdk/react";
import { transformersJS, TransformersChatTransport, TransformersUIMessage } from "@built-in-ai/transformers-js";
const model = transformersJS("HuggingFaceTB/SmolLM2-360M-Instruct", {
device: "webgpu",
dtype: "q4",
worker: new Worker(new URL("./worker.ts", import.meta.url), {
type: "module",
}),
});
function App() {
const { error, status, sendMessage, messages, stop } =
useChat<TransformersUIMessage>({
transport: new TransformersChatTransport(model),
});
// ... your UI
}

Just a few declarative lines. All the complexity is abstracted away, so you can focus on the UX instead.
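To round out the example, here is a hedged sketch of what the "// ... your UI" part could look like. It assumes AI SDK v5 conventions for useChat (messages expose a parts array, sendMessage accepts { text }, and status takes values such as "ready" and "streaming"); those details come from the AI SDK, not from this package's documentation.

// App.tsx: a minimal UI sketch built on the useChat values shown above.
// The message part shapes and status values follow AI SDK v5 conventions
// and are assumptions, not taken from @built-in-ai/transformers-js itself.
import { useState } from "react";
import { useChat } from "@ai-sdk/react";
import { transformersJS, TransformersChatTransport, TransformersUIMessage } from "@built-in-ai/transformers-js";

const model = transformersJS("HuggingFaceTB/SmolLM2-360M-Instruct", {
  device: "webgpu",
  dtype: "q4",
  worker: new Worker(new URL("./worker.ts", import.meta.url), {
    type: "module",
  }),
});

function App() {
  const { error, status, sendMessage, messages, stop } =
    useChat<TransformersUIMessage>({
      transport: new TransformersChatTransport(model),
    });
  const [input, setInput] = useState("");

  return (
    <div>
      {messages.map((message) => (
        <div key={message.id}>
          <strong>{message.role}:</strong>{" "}
          {message.parts.map((part, i) =>
            part.type === "text" ? <span key={i}>{part.text}</span> : null,
          )}
        </div>
      ))}
      {error && <p>{error.message}</p>}
      <form
        onSubmit={(e) => {
          e.preventDefault();
          if (!input.trim()) return;
          sendMessage({ text: input });
          setInput("");
        }}
      >
        <input value={input} onChange={(e) => setInput(e.target.value)} />
        <button type="submit" disabled={status !== "ready"}>
          Send
        </button>
        <button type="button" onClick={() => stop()}>
          Stop
        </button>
      </form>
    </div>
  );
}

Everything the handwritten version tracked by hand, such as loading, streaming, interrupts, and errors, is surfaced through status, messages, stop, and error.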