Add authorization to digest-score
This commit is contained in:
@ -16,13 +16,16 @@ from urllib.parse import urlparse
|
||||
from datetime import datetime
|
||||
import dateutil.parser
|
||||
from google.cloud import storage
|
||||
from features.user_history import FEATURE_COLUMNS
|
||||
|
||||
import concurrent.futures
|
||||
from threading import Lock, RLock
|
||||
from collections import ChainMap
|
||||
import copy
|
||||
|
||||
from features.user_history import FEATURE_COLUMNS
|
||||
from auth import user_token_required, admin_token_required
|
||||
|
||||
|
||||
class ThreadSafeUserFeatures:
|
||||
def __init__(self):
|
||||
self._data = {}
|
||||
@ -219,12 +222,14 @@ def metrics():
|
||||
|
||||
|
||||
@app.route('/refresh', methods=['GET'])
|
||||
@admin_token_required
|
||||
def refresh():
|
||||
refresh_data()
|
||||
return jsonify({'OK': 'yes'}), 200
|
||||
|
||||
|
||||
@app.route('/users/<user_id>/features', methods=['GET'])
|
||||
@admin_token_required
|
||||
def get_user_features(user_id):
|
||||
result = {}
|
||||
df_user = pd.DataFrame([{
|
||||
@ -243,12 +248,13 @@ def get_user_features(user_id):
|
||||
|
||||
|
||||
@app.route('/predict', methods=['POST'])
|
||||
@user_token_required
|
||||
def predict():
|
||||
try:
|
||||
data = request.get_json()
|
||||
app.logger.info(f"predict scoring request: {data}")
|
||||
|
||||
user_id = data.get('user_id')
|
||||
user_id = request.user_id
|
||||
item_features = data.get('item_features')
|
||||
|
||||
if user_id is None:
|
||||
@ -263,14 +269,13 @@ def predict():
|
||||
|
||||
|
||||
@app.route('/batch', methods=['POST'])
|
||||
@user_token_required
|
||||
def batch():
|
||||
start = timer()
|
||||
try:
|
||||
data = request.get_json()
|
||||
app.logger.info(f"batch scoring request: {data}")
|
||||
|
||||
items = data.get('items')
|
||||
user_id = data.get('user_id')
|
||||
user_id = request.user_id
|
||||
if user_id == None:
|
||||
return jsonify({'error': 'no user_id supplied'}), 400
|
||||
if len(items) > 101:
|
||||
|
||||
45
ml/digest-score/auth.py
Normal file
45
ml/digest-score/auth.py
Normal file
@ -0,0 +1,45 @@
|
||||
import os
|
||||
import jwt
|
||||
from flask import request, jsonify
|
||||
from functools import wraps
|
||||
|
||||
SECRET_KEY = os.getenv('JWT_SECRET')
|
||||
ADMIN_SECRET_KEY = os.getenv('JWT_ADMIN_SECRET_KEY')
|
||||
|
||||
def user_token_required(f):
|
||||
@wraps(f)
|
||||
def decorated(*args, **kwargs):
|
||||
token = None
|
||||
if 'Authorization' in request.headers:
|
||||
print("request.headers['Authorization'].split(" ")[1]", request.headers['Authorization'].split(" ")[1])
|
||||
token = request.headers['Authorization'].split(" ")[1]
|
||||
if not token:
|
||||
return jsonify({'message': 'Token is missing!'}), 401
|
||||
try:
|
||||
data = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
|
||||
request.user_id = data['uid']
|
||||
except jwt.ExpiredSignatureError:
|
||||
return jsonify({'message': 'Token has expired!'}), 401
|
||||
except jwt.InvalidTokenError:
|
||||
return jsonify({'message': 'Token is invalid!'}), 401
|
||||
return f(*args, **kwargs)
|
||||
return decorated
|
||||
|
||||
def admin_token_required(f):
|
||||
@wraps(f)
|
||||
def decorated(*args, **kwargs):
|
||||
token = None
|
||||
if 'Authorization' in request.headers:
|
||||
token = request.headers['Authorization'].split(" ")[1]
|
||||
if not token:
|
||||
return jsonify({'message': 'Token is missing!'}), 401
|
||||
try:
|
||||
data = jwt.decode(token, ADMIN_SECRET_KEY, algorithms=["HS256"])
|
||||
if data['role'] != 'admin':
|
||||
return jsonify({'message': 'Admin token required!'}), 403
|
||||
except jwt.ExpiredSignatureError:
|
||||
return jsonify({'message': 'Token has expired!'}), 401
|
||||
except jwt.InvalidTokenError:
|
||||
return jsonify({'message': 'Token is invalid!'}), 401
|
||||
return f(*args, **kwargs)
|
||||
return decorated
|
||||
@ -9,5 +9,6 @@ sklearn2pmml
|
||||
sqlalchemy
|
||||
pyarrow
|
||||
requests
|
||||
PyJWT
|
||||
prometheus_client
|
||||
xgboost==2.1.0
|
||||
|
||||
@ -3,6 +3,7 @@ import client from 'prom-client'
|
||||
import { env } from '../env'
|
||||
import { registerMetric } from '../prometheus'
|
||||
import { logError } from '../utils/logger'
|
||||
import { createWebAuthToken } from '../routers/auth/jwt_helpers'
|
||||
|
||||
export interface Feature {
|
||||
library_item_id?: string
|
||||
@ -77,10 +78,12 @@ class ScoreClientImpl implements ScoreClient {
|
||||
|
||||
async getScores(data: ScoreApiRequestBody): Promise<ScoreApiResponse> {
|
||||
const start = Date.now()
|
||||
const authToken = createWebAuthToken(data.user_id)
|
||||
|
||||
try {
|
||||
const response = await axios.post<ScoreApiResponse>(this.apiUrl, data, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${authToken}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
timeout: 5000,
|
||||
|
||||
Reference in New Issue
Block a user