
Mastering Amazon Bedrock Custom Models Fine-tuning (Part 1): Getting started with Fine-tuning
Explore fine-tuning and RAG techniques for foundation models, learn how to choose the appropriate approach for your use case, and walk through the fine-tuning process, including data preparation, hyperparameter tuning, and evaluation.
In this post, we cover:
- A brief overview of fine-tuning
- A brief background on RAG (Retrieval-Augmented Generation)
- The criteria for choosing between fine-tuning and RAG
- Getting started with fine-tuning
When to choose fine-tuning:
- Specialized Tasks: Fine-tuning is ideal for narrow, specialized tasks where precision and performance are paramount. For example, if you're developing a medical diagnosis model, fine-tuning on a curated dataset of medical records will yield highly accurate results.
- High Performance and Low Latency: If your application demands low latency and high throughput, fine-tuning is the better option. Fine-tuned models do not require an additional retrieval step, making them faster in inference.
- Curated Datasets: If you have access to a well-defined, labeled, and curated dataset relevant to your specific task, fine-tuning can leverage this data to optimize performance.
- Quality of Prediction: For tasks where the quality and accuracy of predictions are critical, fine-tuning allows you to tailor the model closely to your specific requirements.
Pros of fine-tuning:
- High Performance: Optimized for specific tasks, leading to better accuracy and performance.
- Low Latency: Faster inference times as there is no need for an additional retrieval step.
- Task Specificity: Tailored to perform exceptionally well on the specific task it is trained on.
Cons of fine-tuning:
- Cost: Fine-tuning requires a substantial initial investment in training, including preprocessing costs for scraping, transforming, and cleaning the data.
- Loses Generalization: Fine-tuned models are highly specialized, meaning different models are needed for different tasks.
- Not Ideal for Frequently Changing Data: As the model is trained on a static dataset, it does not adapt well to dynamic data environments.
When to choose RAG:
- Frequently Changing Data: RAG is preferred when the data changes frequently, such as in news agencies or media outlets. The model can retrieve up-to-date information without retraining.
- Broad Domain Knowledge: If your application covers a wide range of topics or domains, RAG can efficiently handle the diversity by retrieving relevant information dynamically.
- Limited Labeled Data: RAG is advantageous when you lack a substantial labeled dataset. It uses pre-trained models and retrieves context from external sources, reducing the need for extensive training data.
- Cost and Time Efficiency: RAG can be implemented quickly with lower initial costs since it avoids the extensive training process.
Pros of RAG:
- Flexibility: Handles a wide variety of tasks by retrieving relevant information on-the-fly.
- Lower Initial Costs: Avoids the costs associated with training, making it more accessible and faster to deploy.
- Retains Generalization: The base model remains unaltered, maintaining its ability to generalize across different tasks.
Cons of RAG:
- Slower Inference: The retrieval step adds latency, making RAG slower compared to fine-tuned models.
- Complexity: Involves multiple components, such as a vector database, embedding models, and document loaders, which can complicate the system (see the sketch after this list).
- Higher Token Usage: Requires parsing the query along with the context, leading to increased token usage per prompt.
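To make these extra moving parts concrete, here is a minimal, self-contained sketch of the retrieve-then-augment step. Everything in it is a stand-in: the embed function substitutes for a real embedding model and the document list for a vector database, but the shape of the flow is the same.

import numpy as np

def embed(text):
    # Hypothetical embedding function: in a real system this would call an
    # embedding model; here we derive a deterministic pseudo-random vector.
    rng = np.random.default_rng(abs(hash(text)) % (2**32))
    return rng.random(8)

# Toy in-memory "vector store": pre-computed embeddings for a few documents.
documents = [
    "Our refund policy allows returns within 30 days of purchase.",
    "Support is available Monday through Friday, 9am to 5pm.",
]
doc_vectors = np.stack([embed(d) for d in documents])

def retrieve(query, k=1):
    # Return the k documents most similar to the query (cosine similarity).
    q = embed(query)
    scores = doc_vectors @ q / (np.linalg.norm(doc_vectors, axis=1) * np.linalg.norm(q))
    return [documents[i] for i in np.argsort(scores)[::-1][:k]]

# The retrieved context is prepended to the prompt before calling the model.
query = "When can I return a product?"
context = "\n".join(retrieve(query))
prompt = f"Answer using only the context below.\n\nContext:\n{context}\n\nQuestion: {query}"

The retrieval call and the longer, context-stuffed prompt are exactly where the additional latency and token usage described above come from.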
Choosing between fine-tuning and RAG:
- Performance Sensitivity: If your application demands high performance, low latency, and high-quality predictions for a narrow domain, fine-tuning is the recommended approach.
- Dynamic Data Environments: For applications dealing with frequently updated information or broad domain knowledge, RAG is often the more practical and cost-effective solution.
To get started with fine-tuning, we first prepare a training dataset. The example below loads the Databricks Dolly 15k instruction dataset, keeps only the summarization examples, and splits them into train and test sets.

from datasets import load_dataset

# Load the Databricks Dolly 15k instruction dataset.
dolly_dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

# Keep only the summarization examples. To train for question answering or
# information extraction instead, change the filter condition to
# example["category"] == "closed_qa" or "information_extraction".
summarization_dataset = dolly_dataset.filter(lambda example: example["category"] == "summarization")
summarization_dataset = summarization_dataset.remove_columns("category")

# Split the dataset in two; the test split is used for evaluation at the end.
train_and_test_dataset = summarization_dataset.train_test_split(test_size=0.1)
train_and_test_dataset["test"][0]
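The fine-tuning job reads its training data from Amazon S3, so the split needs to be serialized and uploaded first. The following is a minimal sketch, assuming the JSONL-plus-prompt-template instruction-tuning format used in the SageMaker JumpStart text-generation examples; the default session bucket and the template keys are placeholders you may need to adapt for your chosen model. It also defines the train_data_location consumed by estimator.fit() below.

import json
from sagemaker import Session
from sagemaker.s3 import S3Uploader

# Write the training split to a JSONL file (one JSON object per line).
local_train_file = "train.jsonl"
with open(local_train_file, "w") as f:
    for example in train_and_test_dataset["train"]:
        f.write(json.dumps(example) + "\n")

# Prompt template describing how each example is assembled during training.
# The placeholder keys ({instruction}, {context}, {response}) match the Dolly
# dataset columns; adjust them if your model expects a different format.
template = {
    "prompt": "Below is an instruction that describes a task, paired with an input "
              "that provides further context. Write a response that appropriately "
              "completes the request.\n\n"
              "### Instruction:\n{instruction}\n\n### Input:\n{context}\n\n",
    "completion": " {response}",
}
with open("template.json", "w") as f:
    json.dump(template, f)

# Upload both files to S3; this prefix is what the training job reads from.
output_bucket = Session().default_bucket()  # placeholder: any S3 bucket you own
train_data_location = f"s3://{output_bucket}/dolly_dataset"
S3Uploader.upload(local_train_file, train_data_location)
S3Uploader.upload("template.json", train_data_location)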
With the training data uploaded to S3, we can launch the fine-tuning job using a SageMaker JumpStart estimator with instruction tuning enabled.

from sagemaker.jumpstart.estimator import JumpStartEstimator

# model_id and model_version identify the JumpStart base model selected earlier.
estimator = JumpStartEstimator(
    model_id=model_id,
    model_version=model_version,
    instance_type="ml.g5.12xlarge",
    instance_count=2,
    environment={"accept_eula": "true"},
)

# Instruction tuning is disabled by default, so enable it to use the
# instruction-tuning dataset prepared above.
estimator.set_hyperparameters(
    instruction_tuned="True",
    epoch="5",
    max_input_length="1024",
)

# train_data_location is the S3 prefix holding train.jsonl and template.json.
estimator.fit({"training": train_data_location})
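Once the training job finishes, the fine-tuned model can be deployed to a real-time endpoint and sanity-checked against the held-out test split. This is a minimal sketch following the JumpStart text-generation examples; the payload structure and the accept_eula custom attribute are assumptions that may differ for other base models.

# Deploy the fine-tuned model to a SageMaker real-time endpoint.
predictor = estimator.deploy()

# Run a quick evaluation on one held-out example from the test split.
test_example = train_and_test_dataset["test"][0]
payload = {
    "inputs": f"### Instruction:\n{test_example['instruction']}\n\n"
              f"### Input:\n{test_example['context']}\n\n",
    "parameters": {"max_new_tokens": 256},
}
response = predictor.predict(payload, custom_attributes="accept_eula=true")
print(response)

# Clean up the endpoint when you are done to avoid ongoing charges.
# predictor.delete_endpoint()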
Any opinions in this post are those of the individual author and may not reflect the opinions of AWS.