diff --git a/packages/api/src/apollo.ts b/packages/api/src/apollo.ts index 68e1c3645..26725b499 100644 --- a/packages/api/src/apollo.ts +++ b/packages/api/src/apollo.ts @@ -15,6 +15,7 @@ import { import { ApolloServer } from 'apollo-server-express' import { ExpressContext } from 'apollo-server-express/dist/ApolloServer' import { ApolloServerPlugin } from 'apollo-server-plugin-base' +import DataLoader from 'dataloader' import { Express } from 'express' import * as httpContext from 'express-http-context2' import type http from 'http' @@ -30,6 +31,8 @@ import { functionResolvers } from './resolvers/function_resolvers' import { ClaimsToSet, RequestContext, ResolverContext } from './resolvers/types' import ScalarResolvers from './scalars' import typeDefs from './schema' +import { batchGetHighlightsFromLibraryItemIds } from './services/highlights' +import { batchGetLabelsFromLibraryItemIds } from './services/labels' import { countDailyServiceUsage, createServiceUsage, @@ -100,6 +103,10 @@ const contextFunc: ContextFunction = async ({ dataSources: { readingProgress: new ReadingProgressDataSource(), }, + dataLoaders: { + labels: new DataLoader(batchGetLabelsFromLibraryItemIds), + highlights: new DataLoader(batchGetHighlightsFromLibraryItemIds), + }, } return ctx diff --git a/packages/api/src/resolvers/function_resolvers.ts b/packages/api/src/resolvers/function_resolvers.ts index 2a6396474..e6b997332 100644 --- a/packages/api/src/resolvers/function_resolvers.ts +++ b/packages/api/src/resolvers/function_resolvers.ts @@ -24,11 +24,6 @@ import { } from '../generated/graphql' import { getAISummary } from '../services/ai-summaries' import { findUserFeatures } from '../services/features' -import { - findHighlightsByLibraryItemId, - highlightsLoader, -} from '../services/highlights' -import { labelsLoader } from '../services/labels' import { findRecommendationsByLibraryItemId } from '../services/recommendation' import { findUploadFileById } from '../services/upload_file' import { @@ -443,7 +438,7 @@ export const functionResolvers = { if (article.labels) return article.labels if (article.labelNames?.length) { - return labelsLoader.load(article.id) + return ctx.dataLoaders.labels.load(article.id) } return [] @@ -519,7 +514,7 @@ export const functionResolvers = { if (item.labels) return item.labels if (item.labelNames?.length) { - return labelsLoader.load(item.id) + return ctx.dataLoaders.labels.load(item.id) } return [] @@ -566,7 +561,7 @@ export const functionResolvers = { if (item.highlights) return item.highlights if (item.highlightAnnotations?.length) { - const highlights = await highlightsLoader.load(item.id) + const highlights = await ctx.dataLoaders.highlights.load(item.id) return highlights.map(highlightDataToHighlight) } diff --git a/packages/api/src/resolvers/types.ts b/packages/api/src/resolvers/types.ts index 6fc24405f..927dee9a1 100644 --- a/packages/api/src/resolvers/types.ts +++ b/packages/api/src/resolvers/types.ts @@ -1,11 +1,14 @@ /* eslint-disable @typescript-eslint/ban-types */ import { Span } from '@opentelemetry/api' import { Context as ApolloContext } from 'apollo-server-core' +import DataLoader from 'dataloader' import * as jwt from 'jsonwebtoken' import { EntityManager } from 'typeorm' import winston from 'winston' -import { PubsubClient } from '../pubsub' import { ReadingProgressDataSource } from '../datasources/reading_progress_data_source' +import { Highlight } from '../entity/highlight' +import { Label } from '../entity/label' +import { PubsubClient } from '../pubsub' export interface Claims { uid: string @@ -41,6 +44,10 @@ export interface RequestContext { dataSources: { readingProgress: ReadingProgressDataSource } + dataLoaders: { + labels: DataLoader + highlights: DataLoader + } } export type ResolverContext = ApolloContext diff --git a/packages/api/src/services/highlights.ts b/packages/api/src/services/highlights.ts index 0b5c527d2..5978ac676 100644 --- a/packages/api/src/services/highlights.ts +++ b/packages/api/src/services/highlights.ts @@ -22,7 +22,7 @@ export type HighlightEvent = Merge< EntityEvent > -const batchGetHighlightsFromLibraryItemIds = async ( +export const batchGetHighlightsFromLibraryItemIds = async ( libraryItemIds: readonly string[] ): Promise => { const libraryItems = await authTrx(async (tx) => @@ -43,10 +43,6 @@ const batchGetHighlightsFromLibraryItemIds = async ( ) } -export const highlightsLoader = new DataLoader( - batchGetHighlightsFromLibraryItemIds -) - export const getHighlightLocation = (patch: string): number | undefined => { const dmp = new diff_match_patch() const patches = dmp.patch_fromText(patch) diff --git a/packages/api/src/services/labels.ts b/packages/api/src/services/labels.ts index 2d0d5c0af..5cbcbd372 100644 --- a/packages/api/src/services/labels.ts +++ b/packages/api/src/services/labels.ts @@ -24,7 +24,7 @@ export type LabelEvent = Merge< EntityEvent > -const batchGetLabelsFromLibraryItemIds = async ( +export const batchGetLabelsFromLibraryItemIds = async ( libraryItemIds: readonly string[] ): Promise => { const libraryItems = await authTrx((tx) => @@ -41,8 +41,6 @@ const batchGetLabelsFromLibraryItemIds = async ( ) } -export const labelsLoader = new DataLoader(batchGetLabelsFromLibraryItemIds) - export const findOrCreateLabels = async ( labels: CreateLabelInput[], userId: string