Enforcing a completion schema on Llama 3.2 11B Vision using Hugging Face TGI images on Amazon SageMaker
Generative models can return completions that may not always comply with a user-defined schema, even after using prompt engineering. This blog post demonstrates a method to ensure schema-compliant outputs using the guidance feature in Hugging Face TGI images on Amazon SageMaker. This feature can enforce a JSON grammar, eliminating the need for human intervention in validation or correction of inconsistent completions.
Felipe Lopez
Amazon Employee
Published Nov 20, 2024
Generative models are inherently non-deterministic due to their reliance on probabilistic sampling mechanisms. When generating responses, language models sample from a probability distribution over possible next tokens, which introduces variability in the response even for the same prompt. This non-deterministic behavior can be problematic when a model is asked to produce structured text that must adhere strictly to predefined schemas. That would be the case if, for example, the generated text were to be passed directly to a downstream application or stored in a database. In those scenarios, inconsistent or malformed responses can disrupt workflows and cause errors in downstream applications.
So what can we do when we need a model to generate text that complies with a specific schema? How can we guide the model toward the type of completion we want it to generate?
There are two types of approaches:
- Prompt engineering: We can enhance the prompt with detailed information about the schema to be followed, and the prompt may include examples of schema-compliant completions in the case of few-shot prompting. Although powerful and often successful, this approach does not guarantee that the generated text will comply with the schema.
- Token filtering: We can modify the token generation process so that the model can only generate tokens that would result in completions that comply with the schema. Unlike prompt engineering, token filtering is guaranteed to produce schema-compliant completions. Since this approach modifies the generation of new tokens, it requires access to the output logits of the model and cannot be implemented with models deployed with Amazon Bedrock or similar API-based solutions (a minimal sketch of the mechanism follows this list).
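To make the token-filtering idea concrete, here is a toy sketch of the hook that the transformers library exposes for this: `generate()` accepts a `prefix_allowed_tokens_fn` callback that returns the token IDs allowed at each decoding step. The small stand-in model and the digits-only filter below are my own illustration of the mechanism, not a schema enforcer.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Small, openly available model used purely as a stand-in for illustration.
model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Toy constraint: only allow tokens that decode to digits (plus end-of-sequence).
allowed_ids = [
    tok_id
    for tok, tok_id in tokenizer.get_vocab().items()
    if tokenizer.convert_tokens_to_string([tok]).strip().isdigit()
] + [tokenizer.eos_token_id]

def prefix_allowed_tokens_fn(batch_id, input_ids):
    # Called at every decoding step; returns the token IDs the model may emit next.
    return allowed_ids

inputs = tokenizer("The number of days in a week is ", return_tensors="pt")
output = model.generate(
    **inputs,
    max_new_tokens=3,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```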
There are several tools available for token filtering. One of my favorites is the `lm-format-enforcer` library, which builds a function that filters the set of allowed tokens to run constrained text generation with the `transformers` library. Let's explore this tool with a quick example. Imagine you have thousands of receipts to process, and you need to extract the name of the business, its location, the date of the transaction, and the sub-total and total amounts. If the goal is to run analytics, you will need all of the entries to be well-formed, and that need for a guarantee could make you decide to spend several hours doing this manually (please, don't!). Or you could use a multi-modal model like Llama 3.2 with `lm-format-enforcer` to do this for you.

The `lm-format-enforcer` library comes with a great example of enforcing a JSON schema when invoking Llama 3.2 11B Vision. In the hypothetical scenario of processing receipts to extract the desired information, the only thing we would have to change is the definition of the schema and the filtering function `prefix_func`.
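A rough sketch of what those two pieces might look like for the receipt scenario follows. The `Receipt` schema and its field names are my own illustration, and the surrounding setup follows the usual Llama 3.2 Vision pattern in `transformers` together with the library's `JsonSchemaParser` and `build_transformers_prefix_allowed_tokens_fn` helpers; treat it as a sketch rather than the exact code from the original example.

```python
import json
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration
from pydantic import BaseModel
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import (
    build_transformers_prefix_allowed_tokens_fn,
)

# Hypothetical schema for the receipt-extraction scenario.
class Receipt(BaseModel):
    business_name: str
    location: str
    transaction_date: str
    subtotal: float
    total: float

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

# Ask for the fields and include the schema in the prompt as well.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {
                "type": "text",
                "text": "Extract the receipt fields as JSON following this schema: "
                + json.dumps(Receipt.schema()),
            },
        ],
    }
]
image = Image.open("receipt.jpg")  # hypothetical receipt scan
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(image, prompt, add_special_tokens=False, return_tensors="pt").to(model.device)

# Build the token-filtering function from the JSON schema.
parser = JsonSchemaParser(Receipt.schema())
prefix_func = build_transformers_prefix_allowed_tokens_fn(processor.tokenizer, parser)

# Pass the filtering function to generate(); only schema-compliant tokens are allowed.
output = model.generate(**inputs, max_new_tokens=512, prefix_allowed_tokens_fn=prefix_func)
print(processor.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```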
As the last lines of the sketch show, we then pass the `prefix_func` function to the model during generation in the same way as in the Llama 3.2 Vision example. This slight modification forces the model to return only the desired JSON in a format that is syntactically correct.

This approach carries over when deploying a model with Amazon SageMaker. For example, we could change the inference function in the entry point file used when deploying a predictor ... or we could take an even easier approach.
Text Generation Inference (TGI) is a toolkit for serving language models that enables high-performance text generation for popular open-source LLMs, including Llama 3.2. The TGI library is already integrated with the SageMaker SDK and that allows us to deploy high-performance containers on SageMaker to host our models.
Below are the cells required to deploy Llama 3.2 11B Vision Instruct on Amazon SageMaker and to configure guidance to enforce a JSON schema. The code is based on Philipp Schmid's blog on deploying Llama 3.2 Vision on SageMaker, modified here to enforce a JSON schema.
This example will require the `sagemaker` and `pydantic` libraries.
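A minimal install cell, assuming a notebook environment (no versions were pinned in the original post):

```python
# Install or upgrade the libraries used in this example.
%pip install --upgrade --quiet sagemaker pydantic
```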
We are going to use the SageMaker SDK to automatically select the URI of the deep learning container (DLC) for TGI inference, version 2.3.1, which supports Llama 3.2 models.

Llama 3.2 11B Vision is a gated model, so you will be asked for your Hugging Face token. If you have not been granted access to Llama 3 models yet, sign in to your Hugging Face account, read the Meta Llama 3 Acceptable Use Policy, and submit your contact information to be granted access. This process might take a couple of hours.
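A sketch of that step, assuming the notebook runs with an execution role that has SageMaker permissions:

```python
import sagemaker
from sagemaker.huggingface import get_huggingface_llm_image_uri

sess = sagemaker.Session()
role = sagemaker.get_execution_role()  # IAM role used by the endpoint

# Retrieve the Hugging Face TGI DLC that supports Llama 3.2 models.
llm_image = get_huggingface_llm_image_uri("huggingface", version="2.3.1")
print(f"TGI image URI: {llm_image}")
```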
Then we are going to create a `HuggingFaceModel`, just like in the sample blog, and deploy it to a SageMaker real-time endpoint backed by one ml.g6.12xlarge instance.
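A deployment sketch; the environment values (number of GPUs, token limits, timeout) are illustrative assumptions rather than settings from the original post, and the Hugging Face token placeholder must be replaced with your own:

```python
import json
from sagemaker.huggingface import HuggingFaceModel

# TGI configuration for the gated Llama 3.2 11B Vision Instruct model.
config = {
    "HF_MODEL_ID": "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "HUGGING_FACE_HUB_TOKEN": "<your-hugging-face-token>",  # placeholder
    "SM_NUM_GPUS": json.dumps(4),          # ml.g6.12xlarge has 4 GPUs
    "MAX_INPUT_TOKENS": json.dumps(8000),  # assumed limits
    "MAX_TOTAL_TOKENS": json.dumps(8192),
}

llm_model = HuggingFaceModel(role=role, image_uri=llm_image, env=config)

# Deploy a real-time endpoint backed by one ml.g6.12xlarge instance.
llm = llm_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g6.12xlarge",
    container_startup_health_check_timeout=900,  # allow time to load the weights
)
```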
This is where the difference comes in. The first step is to create a schema using the `BaseModel` class from `pydantic`. We are going to pass this schema both to the prompt and to the generation step to limit the set of allowed tokens; Hugging Face has great documentation on how to constrain token generation using grammars. A sketch of the schema and the endpoint invocation follows below. And with that you're done: the model returns a JSON object whose properties are guaranteed to follow our user-defined schema.
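A hypothetical invocation sketch: the `Receipt` schema and its field names are my own illustration, the image URL is a placeholder, and the prompt wording is simplified; the schema is passed both in the prompt text and as the TGI `grammar` parameter that enforces JSON guidance.

```python
import json
from pydantic import BaseModel

# Hypothetical schema for the receipt-extraction scenario.
class Receipt(BaseModel):
    business_name: str
    location: str
    transaction_date: str
    subtotal: float
    total: float

schema = Receipt.schema()

# Simplified prompt: the receipt image is referenced with markdown, as TGI's
# vision-language models expect; adjust to the chat template you are using.
image_url = "https://example.com/receipt.jpg"  # placeholder image location
prompt = (
    f"![]({image_url}) Extract the business name, location, transaction date, "
    f"sub-total and total from this receipt as JSON following this schema: {json.dumps(schema)}"
)

payload = {
    "inputs": prompt,
    "parameters": {
        "max_new_tokens": 512,
        # TGI guidance: constrain token generation to the JSON schema.
        "grammar": {"type": "json", "value": schema},
    },
}

response = llm.predict(payload)
receipt = Receipt(**json.loads(response[0]["generated_text"]))
print(receipt)
```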
Any opinions in this post are those of the individual author and may not reflect the opinions of AWS.