Enforcing a completion schema on Llama 3.2 11B Vision using Hugging Face TGI images on Amazon SageMaker

Generative models can return completions that may not always comply with a user-defined schema, even after using prompt engineering. This blog post demonstrates a method to ensure schema-compliant outputs using the guidance feature in Hugging Face TGI images on Amazon SageMaker. This feature can enforce a JSON grammar, eliminating the need for human intervention in validation or correction of inconsistent completions.

Felipe Lopez
Amazon Employee
Published Nov 20, 2024
Generative models are inherently non-deterministic due to their reliance on probabilistic sampling mechanisms. When generating responses, language models sample from a probability distribution over possible next tokens, which introduces variability in the response even for the same prompt. This non-deterministic behavior can be problematic when a model is asked to produce structured text that must adhere strictly to predefined schemas. That would be the case if, for example, the generated text were to be passed directly to a downstream application or stored in a database. In those scenarios, inconsistent or malformed responses can disrupt workflows and cause errors in downstream applications.
So what can we do when we need a model to generate text that complies with a specific schema? How can we guide the model toward the type of completion we want?
There are two types of approaches:
  1. Prompt engineering: We can enhance the prompt with detailed information on the schema to be followed, possibly including examples of schema-compliant completions in the case of few-shot prompting. Although powerful and often successful, this approach does not guarantee that the generated text will comply with the schema.
  2. Token filtering: We can modify the token generation process so the model can only generate tokens that result in completions that comply with the schema. Unlike prompt engineering, token filtering is guaranteed to produce schema-compliant completions. Since this approach modifies the generation of new tokens, it requires access to the model's output logits and cannot be implemented with models deployed through Amazon Bedrock or similar API-based solutions.

Token filtering with lm-format-enforcer

There are several tools available for token filtering. One of my favorites is the lm-format-enforcer library, which builds a function that filters the set of allowed tokens so you can run constrained text generation with the transformers library.
Let's explore this tool with a quick example. Imagine you have thousands of receipts to process, and you need to extract the name of the business, its location, the date of the transaction, and the sub-total and total amounts. If the goal is to run analytics, you will need all of the entries to be well-formed, and that need for a guarantee could make you decide to spend several hours doing this manually (please, don't!). Or you could use a multi-modal model like Llama 3.2 with lm-format-enforcer to do it for you.
(Image: a sample receipt)
The lm-format-enforcer library comes with a great example of enforcing a JSON schema when invoking Llama 3.2 11B Vision. In the hypothetical scenario of processing receipts to extract the desired information, the only things we would have to change are the schema definition and the filtering function prefix_func.
We can then pass prefix_func to the model during generation, just as in the Llama 3.2 Vision example. This slight modification forces the model to return only the desired JSON, in a format that is syntactically correct.
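As a rough sketch of what that could look like (the Receipt schema, local image path, and prompt text are hypothetical, and the example relies on lm-format-enforcer's transformers integration):

```python
import torch
from PIL import Image
from pydantic import BaseModel
from transformers import AutoProcessor, MllamaForConditionalGeneration
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn

# Hypothetical schema for the fields we want to extract from each receipt
class Receipt(BaseModel):
    business_name: str
    location: str
    transaction_date: str
    subtotal: float
    total: float

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

# Build the token-filtering function from the JSON schema of the pydantic model
parser = JsonSchemaParser(Receipt.model_json_schema())
prefix_func = build_transformers_prefix_allowed_tokens_fn(processor.tokenizer, parser)

# Standard Llama 3.2 Vision prompting: one image plus an extraction instruction
image = Image.open("sample-receipt.jpg")  # hypothetical local image
messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "Extract the business name, location, transaction date, "
                                 "sub-total and total from this receipt as JSON."},
    ],
}]
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt").to(model.device)

# prefix_allowed_tokens_fn only allows tokens that keep the JSON valid under the schema
output = model.generate(**inputs, max_new_tokens=512, prefix_allowed_tokens_fn=prefix_func)
print(processor.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```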
This can be carried over when deploying a model with Amazon SageMaker. For example, we could change the inference function in the entry point file used when deploying a predictor ... or we could take an even easier approach.

Using guidance with Hugging Face TGI images on Amazon SageMaker

Text Generation Inference (TGI) is a toolkit for serving language models that enables high-performance text generation for popular open-source LLMs, including Llama 3.2. The TGI library is already integrated with the SageMaker SDK, which allows us to deploy high-performance containers on SageMaker to host our models.
Below are the cells required to deploy Llama 3.2 11B Vision Instruct on Amazon SageMaker and to configure guidance to enforce a JSON schema. The code is based on Philipp Schmid's blog on deploying Llama 3.2 Vision on SageMaker, modified here to enforce a JSON schema.

Install dependencies

This example will require the sagemaker and pydantic libraries.
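If they are not already available in your environment, a quick install cell could look like this:

```python
# Install or upgrade the libraries used in this example
%pip install --upgrade --quiet sagemaker pydantic
```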

Define session and role
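A typical setup cell in a SageMaker notebook looks like the following sketch (the fallback role name is a placeholder for local environments):

```python
import boto3
import sagemaker

sess = sagemaker.Session()

try:
    # Works when running inside a SageMaker notebook or Studio
    role = sagemaker.get_execution_role()
except ValueError:
    # Placeholder for local environments: replace with your own SageMaker execution role
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

print(f"sagemaker role arn: {role}")
print(f"sagemaker session region: {sess.boto_region_name}")
```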

Find TGI image

We are going to use the SageMaker SDK to automatically select the URI of the deep learning container (DLC) for TGI inference, version 2.3.1, which supports Llama 3.2 models.
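Using the SDK helper, that looks roughly like this:

```python
from sagemaker.huggingface import get_huggingface_llm_image_uri

# Retrieve the TGI DLC image URI for version 2.3.1
llm_image = get_huggingface_llm_image_uri("huggingface", version="2.3.1")
print(f"llm image uri: {llm_image}")
```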

Enter Hugging Face token

Llama 3.2 11B Vision is a gated model, so you will be asked for your Hugging Face token. If you have not been granted access to Llama 3 models yet, sign in to your Hugging Face account, read the Meta Llama 3 Acceptable Use Policy, and submit your contact information to be granted access. This process might take a couple of hours.
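One simple way to capture the token in a notebook is an interactive prompt (a sketch; store the token however your security requirements dictate):

```python
from getpass import getpass

# Llama 3.2 11B Vision is gated: an access token with approved access is required
HF_TOKEN = getpass("Enter your Hugging Face token: ")
```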

Create Hugging Face model with the image URI

Then we create a HuggingFaceModel, just as in the reference blog, and deploy it to a SageMaker real-time endpoint backed by one ml.g6.12xlarge instance.
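A sketch of those cells, following the reference blog, could look like the following; the environment values (token limits, GPU count, timeout) are illustrative and may need tuning for your workload:

```python
import json
from sagemaker.huggingface import HuggingFaceModel

# TGI container configuration (illustrative values)
config = {
    "HF_MODEL_ID": "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "HUGGING_FACE_HUB_TOKEN": HF_TOKEN,
    "SM_NUM_GPUS": json.dumps(4),        # ml.g6.12xlarge has 4 GPUs
    "MAX_INPUT_LENGTH": "4096",
    "MAX_TOTAL_TOKENS": "8192",
    "MESSAGES_API_ENABLED": "true",      # expose the OpenAI-style Messages API
}

llm_model = HuggingFaceModel(role=role, image_uri=llm_image, env=config)

# Deploy a real-time endpoint backed by one ml.g6.12xlarge instance
llm = llm_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g6.12xlarge",
    container_startup_health_check_timeout=900,
)
```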

Define the completion schema using pydantic

This is where the difference comes in. The first step is to create a schema using the BaseModel class from pydantic. We are going to pass this schema both to the prompt and to the generation step to limit the set of allowed tokens. Hugging Face has great documentation on how to constrain token generation using grammars.
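A hypothetical schema for the receipt example could look like this:

```python
from pydantic import BaseModel

# Hypothetical schema describing the fields we want back for each receipt
class Receipt(BaseModel):
    business_name: str
    location: str
    transaction_date: str
    subtotal: float
    total: float

# JSON schema that goes into the prompt and into the generation request
json_schema = Receipt.model_json_schema()
```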
Include the schema in the prompt
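For example (the image URL and instruction text are placeholders):

```python
import json

# Placeholder receipt image that the endpoint can download
image_url = "https://example.com/sample-receipt.jpg"

instruction = (
    "Extract the business name, location, transaction date, sub-total and total "
    "from this receipt. Respond only with a JSON object that follows this schema:\n"
    f"{json.dumps(json_schema)}"
)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": instruction},
        ],
    }
]
```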
Invoke a prediction passing the schema as a JSON grammar
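A sketch of the request, assuming the Messages API is enabled and using TGI's response_format convention for guidance (a grammar type plus the schema as its value):

```python
payload = {
    "messages": messages,
    "max_tokens": 512,
    "temperature": 0.1,
    # The guidance feature: decoding is constrained to match the JSON schema
    "response_format": {"type": "json_object", "value": json_schema},
}

response = llm.predict(payload)
print(response["choices"][0]["message"]["content"])
```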
And you're done! The model returned a JSON object whose properties are guaranteed to follow our user-defined schema.
 

Any opinions in this post are those of the individual author and may not reflect the opinions of AWS.
