
A new way to implement Product Search using Amazon Titan Multimodal Embeddings in Amazon Bedrock
The advantage of multimodal embeddings lies in their ability to capture relationships and dependencies across different types of data, because images and text are projected into a single shared vector space. These embeddings are useful in applications such as image captioning, video analysis, sentiment analysis of multimedia content, and recommendation systems, where understanding and integrating information from multiple modalities is crucial. In this post, we generate Amazon Titan Multimodal embeddings for fashion product images and their descriptions, index them in an Amazon OpenSearch Serverless vector collection, and run both text-based and image-based product searches against the index.
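Because image and text embeddings live in the same vector space, cross-modal relatedness reduces to a vector comparison. As a minimal illustration (the four-element vectors below are hypothetical stand-ins for the real 1,024-dimensional Titan embeddings), cosine similarity scores how related an image is to a piece of text:

import numpy as np

def cosine_similarity(a, b):
    # Cosine similarity between two embedding vectors
    a, b = np.asarray(a), np.asarray(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Hypothetical stand-ins for real Titan Multimodal embeddings
image_embedding = [0.0006, 0.0161, 0.0084, 0.0139]
text_embedding = [0.0012, 0.0150, 0.0091, 0.0120]

print(f"similarity: {cosine_similarity(image_embedding, text_embedding):.4f}")

First, install the dependencies: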
!pip install "boto3>=1.28.57"
!pip install "awscli>=1.29.57"
!pip install "botocore>=1.31.57"
!pip install opensearch-py==2.3.1
!pip install "pypdf>=3.8,<4"
!pip install matplotlib==3.8.2
# Invoke Titan Multimodal Embeddings to generate an embedding per product
# and store the generated embeddings in flat files
import os
import json
import base64
import boto3

bedrock = boto3.client(
    service_name="bedrock-runtime",
    region_name="us-east-1",
    endpoint_url="https://bedrock-runtime.us-east-1.amazonaws.com"
)

def get_embedding_for_productimage_and_description(image_path, image_product_description):
    with open(image_path, "rb") as image_file:
        input_image = base64.b64encode(image_file.read()).decode('utf8')
    body = json.dumps(
        {
            "inputImage": input_image,
            "inputText": image_product_description
        }
    )
    response = bedrock.invoke_model(
        body=body,
        modelId="amazon.titan-embed-image-v1",
        accept="application/json",
        contentType="application/json"
    )
    vector_json = json.loads(response['body'].read().decode('utf8'))
    return vector_json

# Paths to the JSON Lines files
json_file_paths = ['./fashion_data/train_data.json', './fashion_data/val_data.json']

# Limit the number of records to process in each block
block_size = 1000

# Loop through each JSON file
for json_file_path in json_file_paths:
    with open(json_file_path, 'r') as json_file:
        # Process records in blocks of block_size
        documents = []
        block_num = 0
        for line_num, line in enumerate(json_file):
            try:
                # Load each line as a separate JSON object
                record = json.loads(line)
                image_path = record['image_path']
                image_product_description = record['product_title']
                image_brand = record['brand']
                image_class = record['class_label']
                image_url = record['image_url']
                # Generate the multimodal embedding for the image plus its description
                multimodal_vector = get_embedding_for_productimage_and_description(
                    "./fashion_data/" + image_path, image_product_description)
                # Build the document that will be indexed later
                embedding_request_body = {
                    "image_path": image_path,
                    "image_product_description": image_product_description,
                    "image_brand": image_brand,
                    "image_class": image_class,
                    "image_url": image_url,
                    "multimodal_vector": multimodal_vector['embedding']
                }
                # Add the current record to the list of documents
                documents.append(embedding_request_body)
                # Check if it's time to write a block of documents
                if line_num > 0 and line_num % block_size == 0:
                    block_num = line_num // block_size
                    output_file_path = f"embedding_requests_{os.path.basename(json_file_path)}_block{block_num}.json"
                    with open(output_file_path, 'w') as output_file:
                        json.dump(documents, output_file, indent=2)
                    print(f"Processed {len(documents)} records and saved to {output_file_path}")
                    # Reset the documents list for the next block
                    documents = []
            except Exception as ex:
                print(f"Error processing record: {ex}")
        # Write any remaining documents after processing all lines
        if documents:
            output_file_path = f"embedding_requests_{os.path.basename(json_file_path)}_block{block_num + 1}.json"
            with open(output_file_path, 'w') as output_file:
                json.dump(documents, output_file, indent=2)
            print(f"Processed {len(documents)} records and saved to {output_file_path}")
[
  {
    "image_path": "images/train/3.jpeg",
    "image_product_description": "Women's Digital Cotton Linen Blend Saree with Unstitched Blouse Piece(DigiPatta)",
    "image_brand": "PERFECTBLUE",
    "image_class": "saree",
    "image_url": "https://m.media-amazon.com/images/I/81Y+je7CEgL._AC_UL320_.jpg",
    "multimodal_vector": [
      0.00060864864,
      0.016107135,
      0.008389459,
      0.01386008,
      ........
    ]
  },
  {
    "image_path": "images/train/4.jpeg",
    "image_product_description": "Designer Sarees Women's Banarasi Cotton Silk Saree With Blouse Piece.",
    "image_brand": "VAIVIDHYAM",
    "image_class": "saree",
    "image_url": "https://m.media-amazon.com/images/I/61B8o9UlqpL._AC_UL320_.jpg",
    "multimodal_vector": [
      -0.0140264565,
      0.03815178,
      0.0105772475,
      ........
    ]
  }
]
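Before indexing, it is worth sanity-checking one of the generated files. This short sketch (the file name is a placeholder; substitute a file produced by the previous step) verifies that every vector has the 1,024 dimensions the index mapping below expects:

import json

# Placeholder file name; substitute a block file produced by the previous step
with open("embedding_requests_train_data.json_block1.json") as f:
    docs = json.load(f)

print(f"{len(docs)} documents loaded")
assert all(len(d["multimodal_vector"]) == 1024 for d in docs), "unexpected vector dimension"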
# Create the OpenSearch Serverless collection
# and create the OpenSearch client
import os
import json
import time
import boto3
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth

vector_store_name = 'product-search-multimodal'
index_name = "product-search-multimodal-index"
encryption_policy_name = "product-search-multimodal-ep"
network_policy_name = "product-search-multimodal-np"
access_policy_name = 'product-search-multimodal-ap'
region = os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
identity = boto3.client('sts').get_caller_identity()['Arn']

aoss_client = boto3.client('opensearchserverless')

# Encryption policy: encrypt the collection with an AWS-owned key
security_policy = aoss_client.create_security_policy(
    name=encryption_policy_name,
    policy=json.dumps(
        {
            'Rules': [{'Resource': ['collection/' + vector_store_name],
                       'ResourceType': 'collection'}],
            'AWSOwnedKey': True
        }),
    type='encryption'
)

# Network policy: allow public network access to the collection
network_policy = aoss_client.create_security_policy(
    name=network_policy_name,
    policy=json.dumps(
        [
            {'Rules': [{'Resource': ['collection/' + vector_store_name],
                        'ResourceType': 'collection'}],
             'AllowFromPublic': True}
        ]),
    type='network'
)

collection = aoss_client.create_collection(name=vector_store_name, type='VECTORSEARCH')

# Wait until the collection is ACTIVE (or FAILED)
while True:
    status = aoss_client.list_collections(collectionFilters={'name': vector_store_name})['collectionSummaries'][0]['status']
    if status in ('ACTIVE', 'FAILED'):
        break
    time.sleep(10)

# Data access policy: grant the current identity full access to the collection and its indexes
access_policy = aoss_client.create_access_policy(
    name=access_policy_name,
    policy=json.dumps(
        [
            {
                'Rules': [
                    {
                        'Resource': ['collection/' + vector_store_name],
                        'Permission': [
                            'aoss:CreateCollectionItems',
                            'aoss:DeleteCollectionItems',
                            'aoss:UpdateCollectionItems',
                            'aoss:DescribeCollectionItems'],
                        'ResourceType': 'collection'
                    },
                    {
                        'Resource': ['index/' + vector_store_name + '/*'],
                        'Permission': [
                            'aoss:CreateIndex',
                            'aoss:DeleteIndex',
                            'aoss:UpdateIndex',
                            'aoss:DescribeIndex',
                            'aoss:ReadDocument',
                            'aoss:WriteDocument'],
                        'ResourceType': 'index'
                    }],
                'Principal': [identity],
                'Description': 'Easy data policy'
            }
        ]),
    type='data'
)

host = collection['createCollectionDetail']['id'] + '.' + region + '.aoss.amazonaws.com'
print(host)

service = 'aoss'
credentials = boto3.Session().get_credentials()
auth = AWSV4SignerAuth(credentials, region, service)

client = OpenSearch(
    hosts=[{'host': host, 'port': 443}],
    http_auth=auth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    pool_maxsize=20
)
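One practical caveat: newly created data access policies can take a short while to propagate, so the first request against the collection may fail with a 403. A simple retry loop, sketched below under that assumption, waits out the propagation window:

import time

# Data access policies may take up to a minute to propagate; retry until the
# collection accepts authenticated requests (a sketch, not production retry logic)
for attempt in range(12):
    try:
        client.indices.exists("product-search-multimodal-index")
        break
    except Exception:
        time.sleep(10)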
# Create the vector index
import json

index_name = "product-search-multimodal-index"
index_body = {
    "mappings": {
        "properties": {
            "image_path": {"type": "text"},
            "image_product_description": {"type": "text"},
            "image_brand": {"type": "text"},
            "image_class": {"type": "text"},
            "image_url": {"type": "text"},
            "multimodal_vector": {
                "type": "knn_vector",
                "dimension": 1024,
                "method": {
                    "engine": "nmslib",
                    "space_type": "cosinesimil",
                    "name": "hnsw",
                    "parameters": {"ef_construction": 512, "m": 16},
                },
            },
        }
    },
    "settings": {
        "index": {
            "number_of_shards": 2,
            "knn.algo_param": {"ef_search": 512},
            "knn": True,
        }
    },
}

try:
    response = client.indices.create(index_name, body=index_body)
    print(json.dumps(response, indent=2))
except Exception as ex:
    print(ex)

# Describe the new vector index
try:
    response = client.indices.get("product-search-multimodal-index")
    print(json.dumps(response, indent=2))
except Exception as ex:
    print(ex)
# Read embeddings from the flat files and insert them into OpenSearch
import os
import json
from opensearchpy.helpers import bulk

# Directory containing the JSON files
json_files_directory = "./"

# Iterate through each JSON file
for filename in os.listdir(json_files_directory):
    if filename.startswith("embedding_requests") and filename.endswith(".json"):
        file_path = os.path.join(json_files_directory, filename)
        # Load the JSON file
        with open(file_path, "r") as file:
            data = json.load(file)
        print(f"Starting indexing for :: {filename}")
        # Use the bulk API to insert the documents for the current file
        success, failed = bulk(
            client,
            data,
            index="product-search-multimodal-index",  # Replace with your OpenSearch index name
            raise_on_exception=True
        )
        print(f"Indexed {success} documents successfully, {len(failed)} documents failed for file: {filename}")
# Test: search the index with a text query
def get_embedding_for_text(text):
    body = json.dumps(
        {
            "inputText": text
        }
    )
    response = bedrock.invoke_model(
        body=body,
        modelId="amazon.titan-embed-image-v1",
        accept="application/json",
        contentType="application/json"
    )
    vector_json = json.loads(response['body'].read().decode('utf8'))
    return vector_json, text

text_embedding = get_embedding_for_text("Georgette Pink saree")

query = {
    "size": 5,
    "query": {
        "knn": {
            "multimodal_vector": {
                "vector": text_embedding[0]['embedding'],
                "k": 5
            }
        }
    },
    "_source": ["image_product_description", "image_path", "image_brand", "image_class", "image_url"]
}

try:
    text_based_search_response = client.search(
        body=query,
        index="product-search-multimodal-index")
    print(json.dumps(text_based_search_response, indent=2))
except Exception as ex:
    print(ex)
The text_based_search_response variable initialized in the code block above holds the documents retrieved from OpenSearch. The next block renders the matching product images together with their descriptions and similarity scores.
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
import textwrap

# Display images and metadata
rows = 2
columns = 5
fig = plt.figure(figsize=(15, 8))

for index, value in enumerate(text_based_search_response["hits"]["hits"], 1):
    ax = fig.add_subplot(rows, columns, index)
    image_path = value["_source"]["image_path"]
    image = np.array(Image.open(f'./fashion_data/{image_path}'))
    # Display the image
    ax.imshow(image)
    ax.axis("off")
    # Display the product description and score below the image
    product_description = value["_source"]["image_product_description"]
    wrapped_description = "\n".join(textwrap.wrap(product_description, width=20))
    score = f'Score: {value["_score"]:.2%}'
    # Set the title and description below the image
    title = f'{wrapped_description}\n{score}'
    ax.set_title(title, fontsize=8, pad=2)

plt.tight_layout()
plt.show()
# Test: search the index with an image query
import os
import json
import base64
import boto3

def get_embedding_for_image(image_path):
    with open(image_path, "rb") as image_file:
        input_image = base64.b64encode(image_file.read()).decode('utf8')
    body = json.dumps(
        {
            "inputImage": input_image
        }
    )
    response = bedrock.invoke_model(
        body=body,
        modelId="amazon.titan-embed-image-v1",
        accept="application/json",
        contentType="application/json"
    )
    vector_json = json.loads(response['body'].read().decode('utf8'))
    return vector_json, input_image

image_embedding = get_embedding_for_image("./fashion_data/images/test/1358.jpeg")

query = {
    "size": 5,
    "query": {
        "knn": {
            "multimodal_vector": {
                "vector": image_embedding[0]['embedding'],
                "k": 5
            }
        }
    },
    "_source": ["image_product_description", "image_path", "image_brand", "image_class", "image_url"]
}

try:
    image_based_search_response = client.search(
        body=query,
        index="product-search-multimodal-index")
    print(json.dumps(image_based_search_response, indent=2))
except Exception as ex:
    print(ex)
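As with the text query, the image_based_search_response variable holds the documents retrieved from OpenSearch; the block below renders the visually similar products.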
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
import textwrap

# Display images and metadata
rows = 2
columns = 5
fig = plt.figure(figsize=(15, 8))

for index, value in enumerate(image_based_search_response["hits"]["hits"], 1):
    ax = fig.add_subplot(rows, columns, index)
    image_path = value["_source"]["image_path"]
    image = np.array(Image.open(f'./fashion_data/{image_path}'))
    # Display the image
    ax.imshow(image)
    ax.axis("off")
    # Display the product description and score below the image
    product_description = value["_source"]["image_product_description"]
    wrapped_description = "\n".join(textwrap.wrap(product_description, width=20))
    score = f'Score: {value["_score"]:.2%}'
    # Set the title and description below the image
    title = f'{wrapped_description}\n{score}'
    ax.set_title(title, fontsize=8, pad=2)

plt.tight_layout()
plt.show()
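When you are done experimenting, delete the collection and its policies so the walkthrough stops incurring charges. A cleanup sketch using the same boto3 clients and names from the creation step:

# Delete the collection and its associated policies
aoss_client.delete_collection(id=collection['createCollectionDetail']['id'])
aoss_client.delete_access_policy(name=access_policy_name, type='data')
aoss_client.delete_security_policy(name=network_policy_name, type='network')
aoss_client.delete_security_policy(name=encryption_policy_name, type='encryption')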
Any opinions in this post are those of the individual author and may not reflect the opinions of AWS.