Websocket follow-up, image creation with Gen AI

Follow-up to previous post

Published Jun 3, 2024

Intro

In the previous post I mentioned that I could have done this even better by using a WebSocket API instead of constantly polling API Gateway. That got me thinking, and I couldn't help myself: I started working on it straight away.
The changes, visualized:
[Diagram]

What are WebSockets?

WebSockets are a technology that lets websites and applications have real-time, two-way communication with servers. You can think of it as a direct conversation with someone, where both sides can speak and listen at the same time, whenever they want.

How it works, simply explained

Normally, when you visit a website, your browser sends a request to the server, and the server sends back a response. This process repeats every time you need new information.
With WebSockets, the connection starts with an initial handshake (like saying hello to someone) over the usual web protocol, HTTP, and is then upgraded to a WebSocket connection. This connection stays open, so data can be sent back and forth instantly and continuously. It's like keeping the phone line open while talking to someone, instead of having to call each other every time.
WebSockets are useful, for example, for applications that need real-time updates, such as live chat systems, online games and live notifications. Because the connection stays open and messages can be sent instantly in both directions, these applications feel faster and more responsive.
In short, WebSockets create a direct, open communication line between your browser and a server, enabling fast, efficient, real-time interactions.
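To make the handshake a bit more concrete: during the HTTP upgrade the client sends a random `Sec-WebSocket-Key` header, and the server answers with a `Sec-WebSocket-Accept` value derived from that key, as specified in RFC 6455. The Go sketch below computes that value. API Gateway does all of this for you, so this is purely illustrative and not part of the project.

```go
package main

import (
	"crypto/sha1"
	"encoding/base64"
	"fmt"
)

// websocketGUID is the fixed GUID defined in RFC 6455, section 1.3.
const websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

// acceptKey computes the Sec-WebSocket-Accept header value a server
// returns for a client's Sec-WebSocket-Key during the HTTP upgrade:
// base64(SHA-1(key + GUID)).
func acceptKey(clientKey string) string {
	sum := sha1.Sum([]byte(clientKey + websocketGUID))
	return base64.StdEncoding.EncodeToString(sum[:])
}

func main() {
	// Example key taken from RFC 6455, section 1.3
	fmt.Println(acceptKey("dGhlIHNhbXBsZSBub25jZQ==")) // s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
}
```

If the server's answer does not match, the browser refuses to open the WebSocket connection.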

Why are WebSockets relevant for this project?

I set up this project with EventBridge and Step Functions handling the image generation, and the process starts with a request from my frontend to API Gateway. Because I don't know how long the image generation will take, it's not practical to use a normal API that waits for a response (API Gateway also limits how long an integration may run, 29 seconds by default for REST APIs).
With a normal API, the client sends a request and waits for the server to respond. If the image generation takes a long time, the client has to wait the entire time, which isn't efficient and can cause the request to time out or the app to appear unresponsive. This is a real problem for tasks like image generation, where the processing time varies with, for example, how detailed the image is. To work around this, my previous version did not return the image in the response at all and used polling instead: the frontend kept sending requests to the backend to check whether the image had been created. That led to a lot of requests.
That is why I wanted to use WebSockets instead. Since a WebSocket keeps the connection open, the backend can reply to the frontend as soon as the image is ready, no matter how long it takes. The frontend gets the result in real time without repeatedly asking the server for updates, which makes the whole process more efficient and improves the user experience.

What is different from the last post?

Frontend

As mentioned in the last post, I used a polling mechanism with React Query and triggered the image generation process with a separate request. The two looked like this:
```tsx
// The function to trigger the generation of an image
const { mutate } = useMutation({
  mutationFn: () => {
    setLoading(true);
    return sendPost("/generate-image", {
      text: post,
      s3Key: `image/${id}`,
      bedrock: bedrock,
    });
  },
});

// The function for polling the /get-image endpoint. When status 200 is returned, polling stops and the image is displayed
const { data } = useQuery({
  queryKey: ["image"],
  refetchInterval: 3000,
  refetchIntervalInBackground: true,
  enabled: refetch,
  queryFn: async () => {
    const imageData = await getImage("/get-image", { s3Key: `image/${id}` });
    if (imageData.status === 200) {
      setRefetch(false);
      setImage(imageData.data);
      setLoading(false);
    }
    return imageData.data;
  },
});
```
The useQuery call requests the backend every 3 seconds until the response has status 200 (OK) and the image comes back. This works well, but it leads to a lot of requests and can be made more efficient.
Because I changed my API to a WebSocket API (explained in the Backend section), I had to make some changes to the frontend as well. What I did was:
  • Remove both the mutation and the useQuery function
  • Uninstall React Query
  • Create a React context for WebSocket requests that uses the browser's built-in WebSocket API
  • Use the context to send requests
I wrapped my whole React app in this context provider so it is available inside my components.
```tsx
function App() {
  return (
    <WebSocketProvider>
      <MainContent />
    </WebSocketProvider>
  );
}
```
The context looks like this:
```tsx
import React, { createContext, useContext, useEffect, useState } from "react";

const API_URL = process.env.REACT_APP_API_URL || "";

// The type of what is returned by this context to the component
interface WebSocketContextType {
  socket: WebSocket | null;
  sendMessage: (message: object) => void;
  isLoading: boolean;
  url: string | null;
}

// The shape of the reply sent by the backend
interface WebSocketMessage {
  url: string;
}

const WebSocketContext = createContext<WebSocketContextType | undefined>(
  undefined
);

export const WebSocketProvider: React.FC<{ children: React.ReactNode }> = ({
  children,
}) => {
  const [socket, setSocket] = useState<WebSocket | null>(null);
  const [isLoading, setIsLoading] = useState<boolean>(false);
  const [url, setUrl] = useState<string | null>(null);

  useEffect(() => {
    const ws = new WebSocket(API_URL);
    setSocket(ws);

    ws.onopen = () => {
      console.log("Connected to WebSocket");
    };

    ws.onmessage = (event) => {
      try {
        const data: WebSocketMessage = JSON.parse(event.data);
        setUrl(data.url);
      } catch (error) {
        console.error("Error parsing WebSocket message: ", error);
      } finally {
        // The reply has arrived (or could not be parsed), so stop loading.
        // The connection is kept open so another image can be generated.
        setIsLoading(false);
      }
    };

    ws.onerror = (error) => {
      console.error("WebSocket Error: ", error);
    };

    ws.onclose = () => {
      console.log("WebSocket connection closed");
    };

    // Close the connection when the provider unmounts
    return () => {
      ws.close();
    };
  }, []);

  const sendMessage = (message: object) => {
    if (socket && socket.readyState === WebSocket.OPEN) {
      setIsLoading(true);
      socket.send(JSON.stringify(message));
    }
  };

  return (
    <WebSocketContext.Provider value={{ socket, sendMessage, isLoading, url }}>
      {children}
    </WebSocketContext.Provider>
  );
};

export const useWebSocket = (): WebSocketContextType => {
  const context = useContext(WebSocketContext);
  if (!context) {
    throw new Error("useWebSocket must be used within a WebSocketProvider");
  }
  return context;
};
```
What is important here?
  • I initialize a few state variables:
    • socket: holds the WebSocket instance.
    • isLoading: tracks whether I have sent a message and am waiting for a reply (so I can show the user that it is loading).
    • url: stores the image URL from the WebSocket reply.
  • useEffect:
    • Runs once, when the provider is first mounted.
    • Establishes the WebSocket connection.
  • sendMessage function:
    • Sends the prompt via the WebSocket to API Gateway.
    • Sets isLoading to true so I can give the user feedback while waiting for a reply.
  • onmessage:
    • When a message is received, parses the response so the component can use the URL to show the image.
    • Finally sets isLoading to false to indicate that the image is generated.
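The provider's state handling boils down to a small state machine. The Go sketch below models it outside of React, purely for illustration: `onSend` mirrors `sendMessage`, and `onMessage` mirrors `ws.onmessage`, including the `finally` behavior of always clearing the loading flag even when the reply cannot be parsed. The `clientState` type and its method names are hypothetical.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// clientState mirrors the provider's isLoading and url state variables.
type clientState struct {
	IsLoading bool
	URL       string
}

// reply mirrors the WebSocketMessage interface: the backend sends {"url": "..."}.
type reply struct {
	URL string `json:"url"`
}

// onSend corresponds to sendMessage: mark a request as in flight.
func (s *clientState) onSend() {
	s.IsLoading = true
}

// onMessage corresponds to ws.onmessage: store the URL if the reply parses,
// and (like the finally block) always stop loading.
func (s *clientState) onMessage(raw string) error {
	defer func() { s.IsLoading = false }()
	var r reply
	if err := json.Unmarshal([]byte(raw), &r); err != nil {
		return fmt.Errorf("parsing WebSocket message: %w", err)
	}
	s.URL = r.URL
	return nil
}

func main() {
	var s clientState
	s.onSend()
	// Hypothetical image URL, for illustration only
	_ = s.onMessage(`{"url": "https://example-bucket.s3.eu-west-1.amazonaws.com/image/abc"}`)
	fmt.Println(s.IsLoading, s.URL)
}
```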
In my component the code now looks like this:
```tsx
import React from "react";
import { Button, Flex, Layout, Spin, Image, Switch } from "antd";
import TextArea from "antd/es/input/TextArea";
import { Typography } from "antd";
import { useWebSocket } from "../../contexts/websocket";

const { Paragraph, Text } = Typography;

const { Header, Content } = Layout;

const headerStyle: React.CSSProperties = {
  textAlign: "center",
  color: "#fff",
  height: 64,
  paddingInline: 48,
  lineHeight: "64px",
  backgroundColor: "#00415a",
};

const contentStyle: React.CSSProperties = {
  textAlign: "center",
  minHeight: 120,
  lineHeight: "120px",
  color: "#fff",
  backgroundColor: "#00719c",
};

// Typed as CSSProperties so that overflow: "hidden" type-checks
const layoutStyle: React.CSSProperties = {
  borderRadius: 8,
  overflow: "hidden",
  width: "calc(50% - 8px)",
  maxWidth: "calc(50% - 8px)",
  marginTop: "10vh",
  height: "100%",
};

const textAreaStyle = {
  width: "80%",
  height: "80%",
};

const buttonStyle = {
  width: "80%",
  height: "80%",
  marginBottom: 16,
  backgroundColor: "#009bd6",
};

const paragraphStyle = {
  margin: 10,
};

const textStyle = {
  color: "white",
};

export default function MainContent() {
  const id = "testing12asswwxc222";
  const { sendMessage, isLoading, url } = useWebSocket();
  const [bedrock, setBedrock] = React.useState(false);
  const [text, setText] = React.useState("");

  const handleBedrockChange = (checked: boolean) => {
    setBedrock(checked);
  };

  const handleSendMessage = () => {
    // "action" tells API Gateway which route to use
    const message = {
      action: "generate",
      text: text,
      s3Key: id,
      bedrock,
    };
    sendMessage(message);
    setText("");
  };

  return (
    <Spin spinning={isLoading}>
      <Flex justify="center">
        <Layout style={layoutStyle}>
          <Header style={headerStyle}>Epic Post To Image POC</Header>
          <Content style={contentStyle}>
            <Paragraph style={paragraphStyle}>
              <Text strong style={textStyle}>
                Write a post as you would for a social media platform. Click
                generate to create an image for your post.
              </Text>
            </Paragraph>

            {/* Controlled input, so setText("") also clears the text area */}
            <TextArea
              style={textAreaStyle}
              rows={4}
              placeholder="Write here"
              value={text}
              onChange={(e) => setText(e.target.value)}
            />
            <Paragraph style={paragraphStyle}>
              <Text strong style={textStyle}>
                Use Bedrock? (Toggle to use bedrock)
              </Text>
              <Switch onChange={handleBedrockChange} />
            </Paragraph>
            {url && (
              <Image width={"80%"} style={{ marginTop: "10px" }} src={url} />
            )}
            <Button
              type="primary"
              style={buttonStyle}
              onClick={handleSendMessage}
            >
              Generate
            </Button>
          </Content>
        </Layout>
      </Flex>
    </Spin>
  );
}
```
When sending a message to the WebSocket API, I have to tell API Gateway which route to use. That is what the action field in the message is for: setting it to "generate" calls the generate route.

Backend

Converting the API Gateway API to a WebSocket API required a few changes.

In the SAM template I made the following changes:

  • The OpenAPI definition of the API Gateway API was removed.
  • A WebSocket API definition was added:
```yaml
WebSocketApi:
  Type: AWS::ApiGatewayV2::Api
  Properties:
    Name: WebSocketApi
    ProtocolType: WEBSOCKET
    RouteSelectionExpression: "$request.body.action"
```
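The `RouteSelectionExpression` of `$request.body.action` is what connects the `action: "generate"` field sent by the frontend to the `generate` route. The Go function below is a simplified model of that lookup, not API Gateway's actual implementation: it reads `action` from the JSON body and falls back to `$default` when the body is not JSON or no route matches. This project defines no `$default` route, so such a message would not reach any Lambda function.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// selectRoute models how "$request.body.action" picks a route: the "action"
// field of the JSON message is used as the route key, and "$default" is
// returned when the body is not JSON or no route with that key exists.
func selectRoute(body string, routes map[string]bool) string {
	var msg struct {
		Action string `json:"action"`
	}
	if err := json.Unmarshal([]byte(body), &msg); err != nil || !routes[msg.Action] {
		return "$default"
	}
	return msg.Action
}

func main() {
	// $connect and $disconnect are chosen by API Gateway itself, not by the body.
	routes := map[string]bool{"generate": true}
	fmt.Println(selectRoute(`{"action":"generate","text":"hi"}`, routes)) // generate
	fmt.Println(selectRoute(`{"action":"unknown"}`, routes))               // $default
	fmt.Println(selectRoute(`not json`, routes))                           // $default
}
```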
To open a connection to this API, I defined the special $connect route, which is used only for opening connections. This route is integrated with a Lambda function:
```yaml
ConnectRoute:
  Type: AWS::ApiGatewayV2::Route
  Properties:
    ApiId: !Ref WebSocketApi
    RouteKey: $connect
    AuthorizationType: NONE
    OperationName: ConnectRoute
    Target: !Join
      - "/"
      - - "integrations"
        - !Ref ConnectIntegration

ConnectIntegration:
  Type: AWS::ApiGatewayV2::Integration
  Properties:
    ApiId: !Ref WebSocketApi
    Description: Connect Integration
    IntegrationType: AWS_PROXY
    IntegrationUri:
      Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${ConnectFunction.Arn}/invocations
```
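The `Fn::Sub` and `!Join` expressions above produce two plain strings: the integration URI that tells API Gateway which Lambda function to invoke, and the route target `integrations/<integration ID>`. The Go sketch below builds the same strings, which can help when checking why a route does not reach its function. The region, account ID, function name and integration ID in `main` are made-up example values.

```go
package main

import "fmt"

// lambdaIntegrationURI builds the same string as the Fn::Sub in the template:
// the API Gateway service path for invoking a Lambda function.
func lambdaIntegrationURI(region, functionARN string) string {
	return fmt.Sprintf("arn:aws:apigateway:%s:lambda:path/2015-03-31/functions/%s/invocations", region, functionARN)
}

// routeTarget mirrors the !Join: a route points at "integrations/<integration ID>".
func routeTarget(integrationID string) string {
	return "integrations/" + integrationID
}

func main() {
	// Hypothetical values, for illustration only
	arn := "arn:aws:lambda:eu-west-1:123456789012:function:ConnectFunction"
	fmt.Println(lambdaIntegrationURI("eu-west-1", arn))
	fmt.Println(routeTarget("abc123")) // integrations/abc123
}
```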
With a connect route in place, I also needed a $disconnect route and a disconnect Lambda, which look very similar:
```yaml
DisconnectRoute:
  Type: AWS::ApiGatewayV2::Route
  Properties:
    ApiId: !Ref WebSocketApi
    RouteKey: $disconnect
    AuthorizationType: NONE
    OperationName: DisconnectRoute
    Target: !Join
      - "/"
      - - "integrations"
        - !Ref DisconnectIntegration

DisconnectIntegration:
  Type: AWS::ApiGatewayV2::Integration
  Properties:
    ApiId: !Ref WebSocketApi
    Description: Disconnect Integration
    IntegrationType: AWS_PROXY
    IntegrationUri:
      Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${DisconnectFunction.Arn}/invocations
```
Finally, I needed a route and an integration for generating the image:
```yaml
GenerateRoute:
  Type: AWS::ApiGatewayV2::Route
  Properties:
    ApiId: !Ref WebSocketApi
    RouteKey: generate
    AuthorizationType: NONE
    OperationName: GenerateRoute
    Target: !Join
      - "/"
      - - "integrations"
        - !Ref GenerateIntegration

GenerateIntegration:
  Type: AWS::ApiGatewayV2::Integration
  Properties:
    ApiId: !Ref WebSocketApi
    Description: forward Integration
    IntegrationType: AWS_PROXY
    IntegrationUri:
      Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${EventBridgeProxyFunction.Arn}/invocations
```
With the regular API (as opposed to the WebSocket API), I could integrate API Gateway directly with EventBridge. That was no longer possible in the same way, so I decided to use a Lambda function as a proxy between API Gateway and EventBridge:
```yaml
EventBridgeProxyFunction:
  Type: AWS::Serverless::Function
  Properties:
    Handler: bootstrap
    Runtime: provided.al2023
    CodeUri: eventbridge-proxy/
    Environment:
      Variables:
        EVENT_BUS_NAME: !Ref PostToImageEventBus
    Policies:
      - Statement:
          - Effect: Allow
            Action:
              - events:PutEvents
            Resource: !Sub "arn:aws:events:${AWS::Region}:${AWS::AccountId}:event-bus/${PostToImageEventBus}"

EventBridgeProxyPermission:
  Type: AWS::Lambda::Permission
  DependsOn:
    - WebSocketApi
  Properties:
    Action: lambda:InvokeFunction
    FunctionName: !Ref EventBridgeProxyFunction
    Principal: apigateway.amazonaws.com
```
The connect and disconnect functions:
```yaml
ConnectFunction:
  Type: AWS::Serverless::Function
  Properties:
    Handler: bootstrap
    Runtime: provided.al2023
    CodeUri: connect/

ConnectPermission:
  Type: AWS::Lambda::Permission
  DependsOn:
    - WebSocketApi
  Properties:
    Action: lambda:InvokeFunction
    FunctionName: !Ref ConnectFunction
    Principal: apigateway.amazonaws.com

DisconnectFunction:
  Type: AWS::Serverless::Function
  Properties:
    Handler: bootstrap
    Runtime: provided.al2023
    CodeUri: disconnect/

DisconnectPermission:
  Type: AWS::Lambda::Permission
  DependsOn:
    - WebSocketApi
  Properties:
    Action: lambda:InvokeFunction
    FunctionName: !Ref DisconnectFunction
    Principal: apigateway.amazonaws.com
```

Lambdas

The three new Lambda functions are connect, disconnect and the EventBridge proxy.
Connect
This function only returns status 200. For the $connect route, a 2xx response tells API Gateway to accept the connection, while any other status code rejects it.
```go
package main

import (
	"context"
	"fmt"

	"github.com/aws/aws-lambda-go/events"
	"github.com/aws/aws-lambda-go/lambda"
)

// handler accepts every connection by returning a 2xx status code
func handler(ctx context.Context, request events.APIGatewayWebsocketProxyRequest) (events.APIGatewayProxyResponse, error) {
	fmt.Println("Connect Event:", request)
	return events.APIGatewayProxyResponse{StatusCode: 200, Body: "Connected"}, nil
}

func main() {
	lambda.Start(handler)
}
```
Disconnect
This function also only returns status 200. The $disconnect route is invoked on a best-effort basis after the connection has closed, so the client never sees this response; the function just logs the event.
```go
package main

import (
	"context"
	"fmt"

	"github.com/aws/aws-lambda-go/events"
	"github.com/aws/aws-lambda-go/lambda"
)

// handler logs the disconnect; the response is not delivered to the client
func handler(ctx context.Context, request events.APIGatewayWebsocketProxyRequest) (events.APIGatewayProxyResponse, error) {
	fmt.Println("Disconnect Event:", request)
	return events.APIGatewayProxyResponse{StatusCode: 200, Body: "Disconnected"}, nil
}

func main() {
	lambda.Start(handler)
}
```
EventBridge proxy
This function replaces the previous direct integration with EventBridge. It forwards the request from the API to EventBridge and adds the WebSocket connection ID, so the workflow knows which client to reply to:
```go
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"os"

	"github.com/aws/aws-lambda-go/events"
	"github.com/aws/aws-lambda-go/lambda"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/eventbridge"
)

// Payload is the message body sent by the frontend
type Payload struct {
	Text    string `json:"text"`
	S3Key   string `json:"s3Key"`
	Bedrock bool   `json:"bedrock"`
}

// EventDetail is forwarded to EventBridge, including the connection ID
// so the workflow can reply to the right client
type EventDetail struct {
	ConnectionId string `json:"connectionId"`
	Text         string `json:"text"`
	S3Key        string `json:"s3Key"`
	Bedrock      bool   `json:"bedrock"`
}

var svc *eventbridge.EventBridge

// Create the EventBridge client once per Lambda execution environment
func init() {
	sesh := session.Must(session.NewSession())
	svc = eventbridge.New(sesh)
}

func handler(ctx context.Context, request events.APIGatewayWebsocketProxyRequest) (events.APIGatewayProxyResponse, error) {
	eventBusName := os.Getenv("EVENT_BUS_NAME")
	if eventBusName == "" {
		// A missing environment variable is a server-side misconfiguration
		log.Println("Error: EVENT_BUS_NAME environment variable is not set")
		return events.APIGatewayProxyResponse{StatusCode: 500}, fmt.Errorf("missing EVENT_BUS_NAME")
	}

	connectionId := request.RequestContext.ConnectionID

	var payload Payload
	if err := json.Unmarshal([]byte(request.Body), &payload); err != nil {
		log.Printf("Error parsing JSON: %s\n", err)
		return events.APIGatewayProxyResponse{StatusCode: 400}, nil
	}

	eventDetail := EventDetail{
		ConnectionId: connectionId,
		Text:         payload.Text,
		S3Key:        payload.S3Key,
		Bedrock:      payload.Bedrock,
	}

	eventDetailJSON, err := json.Marshal(eventDetail)
	if err != nil {
		log.Printf("Error marshaling event detail: %s\n", err)
		return events.APIGatewayProxyResponse{StatusCode: 500}, nil
	}

	input := &eventbridge.PutEventsInput{
		Entries: []*eventbridge.PutEventsRequestEntry{
			{
				Detail:       aws.String(string(eventDetailJSON)),
				DetailType:   aws.String("PreparePrompt"),
				Source:       aws.String("api-gateway"),
				EventBusName: aws.String(eventBusName),
			},
		},
	}

	result, err := svc.PutEventsWithContext(ctx, input)
	if err != nil {
		log.Printf("Error putting event: %s\n", err)
		return events.APIGatewayProxyResponse{StatusCode: 500}, nil
	}

	// PutEvents can succeed as a call while individual entries fail
	if aws.Int64Value(result.FailedEntryCount) > 0 {
		log.Printf("Failed to put event: %v\n", result.Entries)
		return events.APIGatewayProxyResponse{StatusCode: 500}, nil
	}

	log.Printf("PutEvents result: %v\n", result)

	return events.APIGatewayProxyResponse{StatusCode: 200}, nil
}

func main() {
	lambda.Start(handler)
}
```
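One possible refactoring, not part of the original code, is to pull the parsing and mapping out of the handler into a pure function, so it can be unit tested without AWS credentials or a real event bus. The sketch below reuses the `Payload` and `EventDetail` types from the proxy; `buildEventDetail` is a hypothetical name. Note that the `action` field sent by the frontend is silently dropped, since `Payload` has no field for it.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Payload and EventDetail are the same types as in the proxy above.
type Payload struct {
	Text    string `json:"text"`
	S3Key   string `json:"s3Key"`
	Bedrock bool   `json:"bedrock"`
}

type EventDetail struct {
	ConnectionId string `json:"connectionId"`
	Text         string `json:"text"`
	S3Key        string `json:"s3Key"`
	Bedrock      bool   `json:"bedrock"`
}

// buildEventDetail combines the WebSocket connection ID with the client's
// JSON body into the detail string that is sent to EventBridge.
func buildEventDetail(connectionID, body string) (string, error) {
	var p Payload
	if err := json.Unmarshal([]byte(body), &p); err != nil {
		return "", fmt.Errorf("parsing body: %w", err)
	}
	detail, err := json.Marshal(EventDetail{
		ConnectionId: connectionID,
		Text:         p.Text,
		S3Key:        p.S3Key,
		Bedrock:      p.Bedrock,
	})
	if err != nil {
		return "", err
	}
	return string(detail), nil
}

func main() {
	// Hypothetical connection ID and message, for illustration only
	d, err := buildEventDetail("conn-1", `{"action":"generate","text":"A cat","s3Key":"img-1","bedrock":false}`)
	fmt.Println(d, err)
}
```

The handler would then only deal with reading the environment, calling `buildEventDetail` and sending the result with `PutEvents`.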
The other functions look pretty much the same as in the last post. The difference is that at the end of the Step Functions workflow I now reply to the WebSocket client like this:
```go
// Public URL of the generated image in S3
url := fmt.Sprintf("https://%s.s3.%s.amazonaws.com/%s", bucketName, region, objectKey)

// The management API client must point at the WebSocket API's callback endpoint
apiClient := apigatewaymanagementapi.New(sesh, aws.NewConfig().WithEndpoint(websocketEndpoint))

// Marshal the reply instead of building JSON by hand, so the URL is escaped correctly.
// The frontend reads the "url" field in its onmessage handler.
reply, err := json.Marshal(map[string]string{"url": url})
if err != nil {
	return err
}

postToConnectionInput := &apigatewaymanagementapi.PostToConnectionInput{
	ConnectionId: aws.String(event.ConnectionId),
	Data:         reply,
}

_, err = apiClient.PostToConnection(postToConnectionInput)
if err != nil {
	log.Println("Error sending response to WebSocket client:", err)
	return err
}

log.Println("Successfully sent response to WebSocket client:", event.ConnectionId)
```
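The value of `websocketEndpoint` is not shown above. For a WebSocket API on its default domain, the callback endpoint has the form `https://{api-id}.execute-api.{region}.amazonaws.com/{stage}`. The Go sketch below builds it either from its parts or from the `domainName` and `stage` values in the WebSocket request context (for example in the EventBridge proxy, which could pass it along). The function names and example values are hypothetical, and with a custom domain the endpoint looks different.

```go
package main

import "fmt"

// callbackEndpoint builds the Management API endpoint for a WebSocket API
// stage on the default execute-api domain.
func callbackEndpoint(apiID, region, stage string) string {
	return fmt.Sprintf("https://%s.execute-api.%s.amazonaws.com/%s", apiID, region, stage)
}

// endpointFromContext builds the same endpoint from the domainName and stage
// fields of the WebSocket request context.
func endpointFromContext(domainName, stage string) string {
	return "https://" + domainName + "/" + stage
}

func main() {
	// Hypothetical API ID, region and stage name
	fmt.Println(callbackEndpoint("a1b2c3d4e5", "eu-west-1", "prod"))
	fmt.Println(endpointFromContext("a1b2c3d4e5.execute-api.eu-west-1.amazonaws.com", "prod"))
}
```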

Summary

To summarize, using WebSockets makes the application run more smoothly and reduces the number of requests: instead of polling every 3 seconds, the frontend sends one message and gets one reply when the image is ready. This results in a more pleasant user experience and lower bandwidth usage. There are some downsides, though. The IaC definitions become larger, and direct integration with EventBridge is no longer possible, so an extra Lambda function is needed as a proxy. Despite this, I prefer WebSockets over polling because they improve the experience and make everything run more efficiently.
If you want a deeper look at the changes, all the code can be found here.
 
