import { hc } from 'hono/client';

import {
  type ClientWebSocketMessage,
  ServerWebSocketMessageSchema,
} from '@mai/types';

import { logger } from '@utils/logger';

/**
 * The current WebSocket connection, or `null` while disconnected. Assigned by
 * `startWebsocket` and cleared by the socket's `close` handler; the reconnect
 * interval opens a new connection whenever this is `null`.
 */
let ws: WebSocket | null = null;

/**
 * Arguments accepted by `subscribe` for each channel, as an
 * `[options, onMessage]` tuple: `options` carries the id of the entity being
 * watched, and `onMessage` is invoked with that channel's event data.
 *
 * Every tuple element is labeled: TypeScript rejects tuples that mix labeled
 * and unlabeled members.
 */
type ChannelSubscriptionOptions = {
  conversation_updated: [
    options: { conversationId: string },
    onMessage: () => void,
  ];
  search_request_updated: [
    options: { searchRequestId: string },
    onMessage: () => void,
  ];
  chat_completion_chunk: [
    options: {
      conversationId: string;
    },
    onMessage: (args: {
      conversationId: string;
      conversationMessageId: string;
      content: string;
      contentType: 'chunk' | 'information';
    }) => void,
  ];
  analysis_request_run_event: [
    options: {
      analysisRequestRunId: string;
    },
    onMessage: (args: { analysisRequestRunId: string }) => void,
  ];
  insight_request_run_event: [
    options: {
      insightRequestRunId: string;
    },
    onMessage: (args: {
      insightRequestRunId: string;
      label: string;
      progress: number;
    }) => void,
  ];
};

/** Names of the channels that can be subscribed to. */
type ChannelType = keyof ChannelSubscriptionOptions;

/**
 * Active subscriptions for a single channel, keyed by the id of the entity
 * being watched. Each value is the `[options, onMessage]` tuple given to
 * `subscribe`, so an incoming message can be routed to the callback that was
 * registered for its id.
 */
type SubscriptionRegistry<C extends ChannelType> = Map<
  string,
  ChannelSubscriptionOptions[C]
>;

// Keyed by conversationId.
const conversationUpdatedSubscriptions: SubscriptionRegistry<'conversation_updated'> =
  new Map();

// Keyed by searchRequestId.
const searchRequestUpdatedSubscriptions: SubscriptionRegistry<'search_request_updated'> =
  new Map();

// Keyed by conversationId.
const chatCompletionChunkSubscriptions: SubscriptionRegistry<'chat_completion_chunk'> =
  new Map();

// Keyed by analysisRequestRunId.
const analysisRequestRunEventSubscriptions: SubscriptionRegistry<'analysis_request_run_event'> =
  new Map();

// Keyed by insightRequestRunId.
const insightRequestRunEventSubscriptions: SubscriptionRegistry<'insight_request_run_event'> =
  new Map();

/**
 * Subscribes to a channel. When a message for the subscribed entity arrives
 * on that channel, the registered `onMessage` callback is invoked.
 *
 * The subscription is always recorded locally, keyed by the channel's entity
 * id. A later `subscribe` for the same id replaces the earlier callback. A
 * subscribe frame is sent to the server immediately only when a session token
 * is stored and the socket is OPEN. Otherwise nothing is sent now, and the
 * recorded entry is what the socket's `open` handler resends.
 *
 * @param channel - channel to subscribe to
 * @param args - `[options, onMessage]`: the object carrying the channel's
 *   entity id, and the callback for incoming messages
 */
export function subscribe<Channel extends ChannelType = ChannelType>(
  channel: Channel,
  ...args: ChannelSubscriptionOptions[Channel]
) {
  const token = localStorage.getItem('sessionToken');

  // Sends the frame produced by `buildMessage`, logging (not throwing) on
  // failure. Skipped without a token or when the socket is not OPEN:
  // `WebSocket.send` throws while CONNECTING, and the `open` handler resends
  // recorded subscriptions once the connection is up.
  const sendSubscribe = (
    buildMessage: (sessionToken: string) => ClientWebSocketMessage,
  ) => {
    const socket = ws;
    if (!socket || !token || socket.readyState !== WebSocket.OPEN) return;
    try {
      socket.send(JSON.stringify(buildMessage(token)));
    } catch (e) {
      logger.error({ error: e }, `Error subscribing to ${channel}`);
    }
  };

  switch (channel) {
    case 'conversation_updated': {
      const entry = args as ChannelSubscriptionOptions['conversation_updated'];
      const [{ conversationId }] = entry;
      conversationUpdatedSubscriptions.set(conversationId, entry);
      sendSubscribe(
        (sessionToken) =>
          ({
            token: sessionToken,
            payload: {
              type: 'subscribe',
              options: { channel: 'conversation_updated', conversationId },
            },
          }) satisfies ClientWebSocketMessage,
      );
      break;
    }
    case 'search_request_updated': {
      const entry =
        args as ChannelSubscriptionOptions['search_request_updated'];
      const [{ searchRequestId }] = entry;
      searchRequestUpdatedSubscriptions.set(searchRequestId, entry);
      sendSubscribe(
        (sessionToken) =>
          ({
            token: sessionToken,
            payload: {
              type: 'subscribe',
              options: { channel: 'search_request_updated', searchRequestId },
            },
          }) satisfies ClientWebSocketMessage,
      );
      break;
    }
    case 'chat_completion_chunk': {
      const entry = args as ChannelSubscriptionOptions['chat_completion_chunk'];
      const [{ conversationId }] = entry;
      chatCompletionChunkSubscriptions.set(conversationId, entry);
      sendSubscribe(
        (sessionToken) =>
          ({
            token: sessionToken,
            payload: {
              type: 'subscribe',
              options: { channel: 'chat_completion_chunk', conversationId },
            },
          }) satisfies ClientWebSocketMessage,
      );
      break;
    }
    case 'analysis_request_run_event': {
      const entry =
        args as ChannelSubscriptionOptions['analysis_request_run_event'];
      const [{ analysisRequestRunId }] = entry;
      analysisRequestRunEventSubscriptions.set(analysisRequestRunId, entry);
      sendSubscribe(
        (sessionToken) =>
          ({
            token: sessionToken,
            payload: {
              type: 'subscribe',
              options: {
                channel: 'analysis_request_run_event',
                analysisRequestRunId,
              },
            },
          }) satisfies ClientWebSocketMessage,
      );
      break;
    }
    case 'insight_request_run_event': {
      const entry =
        args as ChannelSubscriptionOptions['insight_request_run_event'];
      const [{ insightRequestRunId }] = entry;
      insightRequestRunEventSubscriptions.set(insightRequestRunId, entry);
      sendSubscribe(
        (sessionToken) =>
          ({
            token: sessionToken,
            payload: {
              type: 'subscribe',
              options: {
                channel: 'insight_request_run_event',
                insightRequestRunId,
              },
            },
          }) satisfies ClientWebSocketMessage,
      );
      break;
    }
  }
}

/**
 * Removes the locally recorded subscription for a channel/entity id. When a
 * session token is stored and the socket is OPEN, it also sends an unsubscribe
 * frame to the server. Does nothing at all if no subscription is recorded for
 * the id.
 *
 * @param channel - channel to unsubscribe from
 * @param args - the options object identifying the entity (the same shape as
 *   the first element passed to `subscribe`)
 */
export function unsubscribe<Channel extends ChannelType = ChannelType>(
  channel: Channel,
  args: ChannelSubscriptionOptions[Channel][0],
) {
  const token = localStorage.getItem('sessionToken');

  // Sends the frame produced by `buildMessage`, logging (not throwing) on
  // failure. Skipped without a token or when the socket is not OPEN:
  // `WebSocket.send` throws while CONNECTING, and the `open` handler resends
  // only subscriptions that are still recorded, which this one no longer is.
  const sendUnsubscribe = (
    buildMessage: (sessionToken: string) => ClientWebSocketMessage,
  ) => {
    const socket = ws;
    if (!socket || !token || socket.readyState !== WebSocket.OPEN) return;
    try {
      socket.send(JSON.stringify(buildMessage(token)));
    } catch (e) {
      logger.error({ error: e }, `Error unsubscribing from ${channel}`);
    }
  };

  switch (channel) {
    case 'conversation_updated': {
      const { conversationId } =
        args as ChannelSubscriptionOptions['conversation_updated'][0];
      if (!conversationUpdatedSubscriptions.delete(conversationId)) return;
      sendUnsubscribe(
        (sessionToken) =>
          ({
            token: sessionToken,
            payload: {
              type: 'unsubscribe',
              options: { channel: 'conversation_updated', conversationId },
            },
          }) satisfies ClientWebSocketMessage,
      );
      break;
    }
    case 'search_request_updated': {
      const { searchRequestId } =
        args as ChannelSubscriptionOptions['search_request_updated'][0];
      if (!searchRequestUpdatedSubscriptions.delete(searchRequestId)) return;
      sendUnsubscribe(
        (sessionToken) =>
          ({
            token: sessionToken,
            payload: {
              type: 'unsubscribe',
              options: { channel: 'search_request_updated', searchRequestId },
            },
          }) satisfies ClientWebSocketMessage,
      );
      break;
    }
    case 'chat_completion_chunk': {
      const { conversationId } =
        args as ChannelSubscriptionOptions['chat_completion_chunk'][0];
      if (!chatCompletionChunkSubscriptions.delete(conversationId)) return;
      sendUnsubscribe(
        (sessionToken) =>
          ({
            token: sessionToken,
            payload: {
              type: 'unsubscribe',
              options: { channel: 'chat_completion_chunk', conversationId },
            },
          }) satisfies ClientWebSocketMessage,
      );
      break;
    }
    case 'analysis_request_run_event': {
      const { analysisRequestRunId } =
        args as ChannelSubscriptionOptions['analysis_request_run_event'][0];
      if (!analysisRequestRunEventSubscriptions.delete(analysisRequestRunId))
        return;
      sendUnsubscribe(
        (sessionToken) =>
          ({
            token: sessionToken,
            payload: {
              type: 'unsubscribe',
              options: {
                channel: 'analysis_request_run_event',
                analysisRequestRunId,
              },
            },
          }) satisfies ClientWebSocketMessage,
      );
      break;
    }
    case 'insight_request_run_event': {
      const { insightRequestRunId } =
        args as ChannelSubscriptionOptions['insight_request_run_event'][0];
      if (!insightRequestRunEventSubscriptions.delete(insightRequestRunId))
        return;
      sendUnsubscribe(
        (sessionToken) =>
          ({
            token: sessionToken,
            payload: {
              type: 'unsubscribe',
              options: {
                channel: 'insight_request_run_event',
                insightRequestRunId,
              },
            },
          }) satisfies ClientWebSocketMessage,
      );
      break;
    }
  }
}

/**
 * Opens a new WebSocket connection to the API and wires up its handlers.
 *
 * - `open`: resends every locally recorded subscription, for all five
 *   channels, so the server resumes delivering events after a (re)connect.
 * - `close`: clears the module-level `ws` reference (if it still points at
 *   this socket) so the reconnect interval opens a fresh connection.
 * - `message`: validates the frame against `ServerWebSocketMessageSchema`,
 *   then invokes the callback registered for the message's entity id, if any.
 *   Malformed frames are logged and dropped rather than thrown.
 */
function startWebsocket() {
  const client = hc(
    import.meta.env.MODE === 'development'
      ? 'http://localhost:3001'
      : 'https://api.moderately.ai',
  );
  const socket = client.ws.$ws(0);
  ws = socket;

  socket.addEventListener('open', () => {
    logger.info({}, 'WebSocket connection opened');

    // `subscribe` re-sets an existing map key, which does not add entries, so
    // iterating each map while it runs is safe.
    for (const [
      options,
      onMessage,
    ] of conversationUpdatedSubscriptions.values()) {
      subscribe('conversation_updated', options, onMessage);
    }

    for (const [
      options,
      onMessage,
    ] of searchRequestUpdatedSubscriptions.values()) {
      subscribe('search_request_updated', options, onMessage);
    }

    for (const [
      options,
      onMessage,
    ] of chatCompletionChunkSubscriptions.values()) {
      subscribe('chat_completion_chunk', options, onMessage);
    }

    // These two channels were previously not replayed, so their
    // subscriptions were silently lost after any reconnect.
    for (const [
      options,
      onMessage,
    ] of analysisRequestRunEventSubscriptions.values()) {
      subscribe('analysis_request_run_event', options, onMessage);
    }

    for (const [
      options,
      onMessage,
    ] of insightRequestRunEventSubscriptions.values()) {
      subscribe('insight_request_run_event', options, onMessage);
    }
  });

  socket.addEventListener('close', () => {
    logger.info({}, 'WebSocket connection closed');
    // Only clear the shared reference if it has not already been replaced.
    if (ws === socket) ws = null;
  });

  socket.addEventListener('message', (event) => {
    // Parsing/validation can throw on malformed input; log and drop the frame
    // instead of letting the exception escape the listener.
    let message: ReturnType<typeof ServerWebSocketMessageSchema.parse>;
    try {
      message = ServerWebSocketMessageSchema.parse(JSON.parse(event.data));
    } catch (e) {
      logger.error({ error: e }, 'Error parsing WebSocket message');
      return;
    }

    const type = message.payload.type;
    if (type === 'chat_completion_chunk') {
      const { conversationId, conversationMessageId, content, contentType } =
        message.payload;
      chatCompletionChunkSubscriptions.get(conversationId)?.[1]({
        conversationId,
        conversationMessageId,
        content,
        contentType,
      });
    } else if (type === 'conversation_updated') {
      const { conversationId } = message.payload;
      conversationUpdatedSubscriptions.get(conversationId)?.[1]();
    } else if (type === 'search_request_updated') {
      const { searchRequestId } = message.payload;
      searchRequestUpdatedSubscriptions.get(searchRequestId)?.[1]();
    } else if (type === 'analysis_request_run_event') {
      const { analysisRequestRunId } = message.payload;
      analysisRequestRunEventSubscriptions.get(analysisRequestRunId)?.[1]({
        analysisRequestRunId,
      });
    } else if (type === 'insight_request_run_event') {
      const { insightRequestRunId, label, progress } = message.payload;
      insightRequestRunEventSubscriptions.get(insightRequestRunId)?.[1]({
        insightRequestRunId,
        label,
        progress,
      });
    } else {
      logger.warn({ type }, 'Received unknown message type');
    }
  });
}

// Open the initial connection as soon as this module is loaded.
startWebsocket();

/** How often (ms) to check whether the socket has dropped and must be reopened. */
const RECONNECT_INTERVAL_MS = 2000;

// Poll for a dropped connection: the socket's `close` handler sets `ws` to
// null, which is the signal to open a new connection. (The previous
// zero-delay setTimeout wrapper inside the interval added nothing.)
setInterval(() => {
  if (ws) {
    logger.debug({}, 'WebSocket connection already open');
    return;
  }
  logger.debug({}, 'Reconnecting WebSocket');
  startWebsocket();
}, RECONNECT_INTERVAL_MS);
