package grpcstarter.extensions.transcoding;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import grpcstarter.extensions.transcoding.Transcoder;
import grpcstarter.extensions.transcoding.Util;
import grpcstarter.server.GrpcServerProperties;
import grpcstarter.server.GrpcServerStartedEvent;
import io.grpc.BindableService;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ClientInterceptors;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.MetadataUtils;
import io.grpc.stub.StreamObserver;
import jakarta.annotation.Nonnull;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.context.ApplicationListener;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.util.StreamUtils;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.servlet.function.HandlerFunction;
import org.springframework.web.servlet.function.ServerRequest;
import org.springframework.web.servlet.function.ServerResponse;

/* loaded from: input_file:grpcstarter/extensions/transcoding/DefaultServletTranscoder.class */
/**
 * Servlet (WebMvc.fn) transcoder that exposes gRPC methods as HTTP endpoints.
 *
 * <p>A request is matched either against an auto-mapped route (HTTP {@code POST} whose path, trimmed of
 * {@code '/'}, is a key of the auto-mapping table) or against one of the custom routes produced by
 * {@code Util.fillRoutes}. The matched request body/parameters are transcoded into a protobuf request and
 * forwarded over the channel obtained from {@code Util.getTranscodingChannel} once a
 * {@link GrpcServerStartedEvent} has been received.
 *
 * <ul>
 *   <li>Unary methods produce a single HTTP response (JSON content type when the output is JSON).</li>
 *   <li>Server-streaming methods produce a server-sent-events response, one event per gRPC message.</li>
 *   <li>Any other method type is rejected with {@code 400 Bad Request}.</li>
 * </ul>
 */
public class DefaultServletTranscoder
        implements ServletTranscoder, DisposableBean, ApplicationListener<GrpcServerStartedEvent> {

    /** Request-attribute key under which {@link #route} stores the matched route for {@link #handle}. */
    private static final String MATCHING_ROUTE = String.valueOf(DefaultServletTranscoder.class) + ".matchingRoute";

    /** How long {@link #destroy()} allows the transcoding channel to shut down. */
    private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(15L);

    private final Map<String, Util.Route<ServerRequest>> autoMappingRoutes = new HashMap<>();
    private final List<Util.Route<ServerRequest>> customRoutes = new ArrayList<>();
    private final HeaderConverter headerConverter;
    private final GrpcTranscodingProperties grpcTranscodingProperties;
    private final GrpcServerProperties grpcServerProperties;
    private final TranscodingExceptionResolver transcodingExceptionResolver;

    /**
     * Channel to the started gRPC server; {@code null} until {@link #onApplicationEvent} runs. Volatile because
     * it is assigned by the event listener and read by request-handling threads, which need not be the same thread.
     */
    private volatile Channel channel;

    public DefaultServletTranscoder(
            List<BindableService> list,
            HeaderConverter headerConverter,
            GrpcTranscodingProperties grpcTranscodingProperties,
            GrpcServerProperties grpcServerProperties,
            TranscodingExceptionResolver transcodingExceptionResolver) {
        Util.fillRoutes(list, this.autoMappingRoutes, this.customRoutes, grpcTranscodingProperties);
        this.headerConverter = headerConverter;
        this.grpcTranscodingProperties = grpcTranscodingProperties;
        this.grpcServerProperties = grpcServerProperties;
        this.transcodingExceptionResolver = transcodingExceptionResolver;
    }

    /**
     * Creates the transcoding channel once the gRPC server has started, using the port the server reports.
     *
     * @param grpcServerStartedEvent event whose source carries the started server's port
     */
    @Override
    public void onApplicationEvent(GrpcServerStartedEvent grpcServerStartedEvent) {
        this.channel = Util.getTranscodingChannel(
                grpcServerStartedEvent.getSource().getPort(), this.grpcTranscodingProperties, this.grpcServerProperties);
    }

    /**
     * Finds a route for the request and, if one matches, remembers it as a request attribute so that
     * {@link #handle} can dispatch on it.
     *
     * @param serverRequest the incoming request
     * @return this transcoder as the handler when a route matches, otherwise empty
     */
    @Override
    @Nonnull
    public Optional<HandlerFunction<ServerResponse>> route(@Nonnull ServerRequest serverRequest) {
        Optional<Util.Route<ServerRequest>> matched = findRoute(serverRequest);
        if (matched.isEmpty()) {
            return Optional.empty();
        }
        serverRequest.attributes().put(MATCHING_ROUTE, matched.get());
        return Optional.of(this);
    }

    /**
     * Returns the first matching route: auto-mapped routes (POST only) take precedence over custom routes,
     * and custom routes are tried in list order.
     */
    private Optional<Util.Route<ServerRequest>> findRoute(ServerRequest serverRequest) {
        if (Objects.equals(serverRequest.method(), HttpMethod.POST)) {
            Util.Route<ServerRequest> autoRoute = this.autoMappingRoutes.get(Util.trim(serverRequest.path(), '/'));
            if (autoRoute != null) {
                return Optional.of(autoRoute);
            }
        }
        return this.customRoutes.stream().filter(candidate -> matches(candidate, serverRequest)).findFirst();
    }

    /** A custom route matches when its primary predicate or any of its additional predicates accepts the request. */
    private static boolean matches(Util.Route<ServerRequest> candidate, ServerRequest serverRequest) {
        return candidate.predicate().test(serverRequest)
                || candidate.additionalPredicates().stream().anyMatch(predicate -> predicate.test(serverRequest));
    }

    /**
     * Handles a request previously accepted by {@link #route}.
     *
     * @param serverRequest the request carrying the matched route attribute
     * @return the transcoded HTTP (unary) or SSE (server-streaming) response
     * @throws ResponseStatusException {@code 400} for unsupported gRPC method types or unparsable input
     * @throws IllegalStateException if no matched route is stored on the request
     */
    @Override
    @Nonnull
    public ServerResponse handle(@Nonnull ServerRequest serverRequest) {
        @SuppressWarnings("unchecked")
        Util.Route<ServerRequest> route = (Util.Route<ServerRequest>) serverRequest.attributes().get(MATCHING_ROUTE);
        if (route == null) {
            throw new IllegalStateException(
                    "No matching route stored on request; route(ServerRequest) must match before handle(ServerRequest)");
        }
        MethodDescriptor.MethodType type = route.invokeMethod().getType();
        return switch (type) {
            case UNARY -> processUnaryCall(serverRequest, route);
            case SERVER_STREAMING -> processServerStreamingCall(serverRequest, route);
            default -> throw new ResponseStatusException(
                    HttpStatus.BAD_REQUEST, "Unsupported rpc method type: " + String.valueOf(type));
        };
    }

    /** Opens a new call for the route's method on the given (possibly intercepted) channel. */
    private static ClientCall<Object, Object> getCall(Channel channel, Util.Route<ServerRequest> route) {
        return channel.newCall(route.invokeMethod(), CallOptions.DEFAULT);
    }

    /**
     * Builds a {@link Transcoder} from the full request body, the query/form parameters, and the request attribute
     * {@code Util.URI_TEMPLATE_VARIABLES_ATTRIBUTE} (presumably the matched URI template variables — verify).
     *
     * @throws IllegalStateException if the request body cannot be read
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static Transcoder getTranscoder(ServerRequest serverRequest) {
        try {
            byte[] body = StreamUtils.copyToByteArray(serverRequest.servletRequest().getInputStream());
            return Transcoder.create(new Transcoder.Variable(
                    body,
                    serverRequest.servletRequest().getParameterMap(),
                    (Map) serverRequest.servletRequest().getAttribute(Util.URI_TEMPLATE_VARIABLES_ATTRIBUTE)));
        } catch (IOException e) {
            throw new IllegalStateException("getInputStream failed", e);
        }
    }

    /**
     * Returns the transcoding channel.
     *
     * @throws ResponseStatusException {@code 503} if no {@link GrpcServerStartedEvent} has been received yet
     */
    private Channel requireChannel() {
        Channel current = this.channel;
        if (current == null) {
            throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, "gRPC server is not started yet");
        }
        return current;
    }

    /** Interceptor that attaches the HTTP request headers, converted to gRPC metadata, to the outgoing call. */
    private ClientInterceptor attachRequestHeaders(ServerRequest serverRequest) {
        return MetadataUtils.newAttachHeadersInterceptor(
                this.headerConverter.toMetadata(serverRequest.headers().asHttpHeaders()));
    }

    /**
     * Performs a blocking unary call and converts the reply to an HTTP 200 response. Response headers returned by
     * the gRPC server are copied onto the HTTP response. gRPC errors are delegated to the
     * {@link TranscodingExceptionResolver}.
     */
    private ServerResponse processUnaryCall(ServerRequest serverRequest, Util.Route<ServerRequest> route) {
        AtomicReference<Metadata> responseHeaders = new AtomicReference<>();
        // Trailers must be captured by the interceptor API but are not forwarded to the HTTP response.
        AtomicReference<Metadata> responseTrailers = new AtomicReference<>();
        Transcoder transcoder = getTranscoder(serverRequest);
        try {
            Channel intercepted = ClientInterceptors.intercept(
                    requireChannel(),
                    MetadataUtils.newCaptureMetadataInterceptor(responseHeaders, responseTrailers),
                    attachRequestHeaders(serverRequest));
            Message reply = (Message) ClientCalls.blockingUnaryCall(
                    getCall(intercepted, route), getMessage(route, transcoder));

            ServerResponse.BodyBuilder builder = ServerResponse.ok().headers(httpHeaders -> {
                Metadata captured = responseHeaders.get();
                if (captured != null) {
                    httpHeaders.addAll(this.headerConverter.toHttpHeaders(captured));
                }
            });
            Object out = transcoder.out(reply, route.httpRule());
            if (JsonUtil.canParseJson(out)) {
                builder.contentType(MediaType.APPLICATION_JSON);
            }
            return builder.body(JsonUtil.toJson(out));
        } catch (StatusRuntimeException e) {
            return this.transcodingExceptionResolver.resolve(e);
        }
    }

    /**
     * Starts an asynchronous server-streaming call and relays each reply as an SSE event. The SSE response has no
     * timeout ({@link Duration#ZERO}). The gRPC call is cancelled when the SSE connection reports an error or an
     * event cannot be written.
     */
    private ServerResponse processServerStreamingCall(ServerRequest serverRequest, Util.Route<ServerRequest> route) {
        Transcoder transcoder = getTranscoder(serverRequest);
        Message request = getMessage(route, transcoder);
        ClientCall<Object, Object> call = getCall(
                ClientInterceptors.intercept(requireChannel(), attachRequestHeaders(serverRequest)), route);
        return ServerResponse.sse(sse -> {
            sse.onError(ignored -> call.cancel("SSE error", null));
            ClientCalls.asyncServerStreamingCall(call, request, new StreamObserver<Object>() {
                @Override
                public void onNext(Object value) {
                    try {
                        sse.data(JsonUtil.toJson(transcoder.out((Message) value, route.httpRule())));
                    } catch (IOException e) {
                        // Writing to the SSE stream failed (e.g. the client disconnected): stop the upstream call.
                        call.cancel("SSE send failed", e);
                    }
                }

                @Override
                public void onError(Throwable error) {
                    if (error instanceof StatusRuntimeException statusError) {
                        sse.error(new TranscodingRuntimeException(
                                TranscodingUtil.toHttpStatus(statusError.getStatus()), statusError.getMessage(), null));
                    } else {
                        sse.error(error);
                    }
                }

                @Override
                public void onCompleted() {
                    sse.complete();
                }
            });
        }, Duration.ZERO);
    }

    /**
     * Builds the protobuf request message for the route from the transcoded HTTP input.
     *
     * @throws ResponseStatusException {@code 400} if the input cannot be parsed into the request message
     */
    private static Message getMessage(Util.Route<?> route, Transcoder transcoder) {
        try {
            return Util.buildRequestMessage(transcoder, route);
        } catch (InvalidProtocolBufferException e) {
            throw new ResponseStatusException(HttpStatus.BAD_REQUEST, e.getMessage(), e);
        }
    }

    /** Shuts down the transcoding channel, if one was ever created. */
    @Override
    public void destroy() throws Exception {
        Channel current = this.channel;
        if (current != null) {
            Util.shutdown(current, CHANNEL_SHUTDOWN_TIMEOUT);
        }
    }
}
