import enum
import json
import warnings

from pydantic.utils import get_model
from pydantic.schema import schema, get_flat_models_from_model, get_model_name_map

from .helper import _OpenAPIGenBaseModel, inherit_fom_basemodel

# list of top level class names that we should stop at
STOPPAGE = set(['NoExtraBaseModel', 'ModelMetaclass', 'BaseModel', 'object', 'Enum'])


def get_schemas_inheritance(model_cls):
    """Modify the default OpenAPI schema generated by Pydantic.

    It adds referenced values to subclasses using the allOf field as explained
    in this post:
    https://swagger.io/docs/specification/data-models/inheritance-and-polymorphism
    """
    # get a dictionary that maps the name of each model schema to its Pydantic class
    model_name_map = get_model_mapper(model_cls, STOPPAGE, full=True, include_enum=False)

    # get the standard OpenAPI schema from Pydantic for all the objects
    ref_prefix = '#/components/schemas/'
    schemas = schema(model_name_map.values(), ref_prefix=ref_prefix)['definitions']

    # add the [possibly] needed baseclass to the list of classes.
    # use json.loads instead of eval: JSON literals like true/false/null are not
    # valid Python and would break eval
    schemas['_OpenAPIGenBaseModel'] = json.loads(_OpenAPIGenBaseModel.schema_json())
    model_name_map['_OpenAPIGenBaseModel'] = _OpenAPIGenBaseModel

    # an empty dictionary to collect updated objects
    updated_schemas = {}

    # iterate through all the data models, find the ones which are subclassed and
    # update them based on the properties of their baseclasses
    for name in schemas.keys():
        # find the class object from the class name
        try:
            main_cls = model_name_map[name]
        except KeyError:
            # enum objects are not included
            if 'enum' in schemas[name]:
                continue
            warnings.warn(f'***KeyError: {name} key not found.***')
            top_classes = []
        else:
            top_classes = get_ancestors(main_cls)

        if not top_classes:
            # update the object to inherit from the baseclass which only has type.
            # this is required for dotnet bindings
            if name != '_OpenAPIGenBaseModel':
                updated_schemas[name] = inherit_fom_basemodel(schemas[name])
            continue

        # do the real work and update the current schema to use inheritance
        updated_schemas[name] = set_inheritance(name, top_classes, schemas)

    # replace updated schemas in the original schema
    for name, value in updated_schemas.items():
        schemas[name] = value

    return schemas
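
# A minimal usage sketch, assuming two hypothetical Pydantic v1 models that are
# not part of this package:
#
#     from pydantic import BaseModel
#
#     class Pet(BaseModel):
#         type: str = 'Pet'
#         name: str
#
#     class Dog(Pet):
#         type: str = 'Dog'
#         breed: str
#
#     schemas = get_schemas_inheritance([Dog])
#     # schemas['Dog'] should now use allOf with a $ref to
#     # '#/components/schemas/Pet' instead of repeating Pet's fields.

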
def get_ancestors(cls):
    """Get the list of ancestors for a class, stopping at the classes in STOPPAGE."""
    # use type.mro to go through all the ancestors for this class and collect them
    top_classes = []
    for ancestor in type.mro(cls):
        if ancestor.__name__ in STOPPAGE:
            break
        top_classes.append(ancestor)

    if len(top_classes) < 2:
        # this class is not a subclass
        return []
    else:
        return top_classes
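
# For example, with the hypothetical models above (`class Pet(BaseModel)` and
# `class Dog(Pet)`):
#
#     get_ancestors(Dog)  # -> [Dog, Pet]; stops before BaseModel
#     get_ancestors(Pet)  # -> []; Pet is not a subclass of another model

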
def _check_object_types(source, target, prop):
    """Check if objects with the same name have different types.

    In such a case we need to subclass from one level higher.
    """
    if 'type' not in source:
        return False
    if source['type'] != 'array':
        return source['type'] != target[prop]
    else:
        # for an array check both the type and the type of its items
        return (source['type'], source['items']) != target[prop]
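
# For instance, a field typed as a plain string in the ancestor but as an array
# in the subclass counts as a mismatch (the values below are illustrative):
#
#     _check_object_types(
#         {'type': 'array', 'items': {'type': 'string'}},
#         {'values': 'string'}, 'values')  # -> True

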
def set_inheritance(name, top_classes, schemas):
    """Set inheritance for an object.

    Args:
        name: Name of the object.
        top_classes: List of ancestors for this class.
        schemas: A dictionary of all the schema objects.

    Returns:
        Dict -- updated schema for the object with the input name.
    """
    # this is the list of special keys that we copy in manually
    copied_keys = set(['type', 'properties', 'required', 'additionalProperties'])

    # remove the class itself
    print(f'\nProcessing {name}')
    top_classes = top_classes[1:]
    top_class = top_classes[0]
    tree = ['....' * (i + 1) + c.__name__ for i, c in enumerate(top_classes)]
    print('\n'.join(tree))

    # the schema for the object itself
    object_dict = schemas[name]
    if 'enum' in object_dict:
        return object_dict

    # collect required and properties from top classes and do not include them in
    # the object itself so we don't end up with duplicate values in the schema for
    # the subclass - if it is required then it will be provided upstream.
    top_classes_required = []
    top_classes_prop = {}

    # collect required keys
    for t in top_classes:
        try:
            schema_t = schemas[t.__name__]
        except KeyError as error:
            raise KeyError(f'Failed to find the model name: {error}')
        try:
            tc_required = schema_t['required']
        except KeyError:
            # no required field
            continue
        for r in tc_required:
            top_classes_required.append(r)

    # collect properties
    for t in top_classes:
        tc_prop = schemas[t.__name__]['properties']
        for pn, dt in tc_prop.items():
            # collect the type for every field. This is helpful to catch the cases
            # where the same field name has a different type in the subclass and
            # should be kept to overwrite the original field.
            if 'type' in dt:
                if dt['type'] == 'array':
                    # collect both the type and the type of its items
                    top_classes_prop[pn] = dt['type'], dt['items']
                else:
                    top_classes_prop[pn] = dt['type']
            else:
                top_classes_prop[pn] = '###'  # no type means use of oneOf or allOf

    # create a new schema for this object based on the top level class
    data = {
        'allOf': [
            {
                '$ref': f'#/components/schemas/{top_class.__name__}'
            },
            {
                'type': 'object',
                'required': [],
                'properties': {}
            }
        ]
    }

    # a shallow copy is enough here since `data` is not reused afterwards
    data_copy = dict(data)

    if not top_classes_required and 'required' in object_dict:
        # no required fields in the top level class.
        # add all the required fields to the subclass
        for r in object_dict['required']:
            data_copy['allOf'][1]['required'].append(r)
    elif 'required' in object_dict and top_classes_required:
        # only add the new required fields
        for r in object_dict['required']:
            if r not in top_classes_required:
                data_copy['allOf'][1]['required'].append(r)

    # no required fields - delete it from the dictionary
    if len(data_copy['allOf'][1]['required']) == 0:
        del data_copy['allOf'][1]['required']

    # get the full list of properties and add the ones that don't exist in the
    # ancestor objects.
    properties = object_dict['properties']
    for prop, values in properties.items():
        if prop not in top_classes_prop:
            # new field. add it to the properties
            print(f'Extending: {prop}')
            data_copy['allOf'][1]['properties'][prop] = values
        elif _check_object_types(values, top_classes_prop, prop) \
                or 'type' not in values and ('allOf' in values or 'anyOf' in values):
            # same name, different types
            print(f'Found a field with the same name: {prop}.')
            if len(top_classes) > 1:
                # recurse with the already-sliced list: set_inheritance strips the
                # first item again, so this tries the next ancestor up
                print(f'Trying {name} against {top_classes[1].__name__}.')
                return set_inheritance(name, top_classes, schemas)
            else:
                # try against the base object. _OpenAPIGenBaseModel is listed twice
                # because set_inheritance removes the first item as the class itself
                print(f'Trying {name} against OpenAPI base object.')
                _top_classes = [_OpenAPIGenBaseModel, _OpenAPIGenBaseModel]
                return set_inheritance(name, _top_classes, schemas)

    try:
        data_copy['allOf'][1]['properties']['type'] = properties['type']
    except KeyError:
        print(f'Found object with no type: {name}')

    if 'additionalProperties' in object_dict:
        data_copy['allOf'][1]['additionalProperties'] = \
            object_dict['additionalProperties']

    # add other items in addition to copied_keys
    for key, value in schemas[name].items():
        if key in copied_keys:
            continue
        data_copy[key] = value

    return data_copy
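
# Sketch of the transformation for the hypothetical `Dog`/`Pet` pair above:
# a flat schema like
#
#     {'title': 'Dog', 'type': 'object', 'required': ['name', 'breed'],
#      'properties': {'type': {...}, 'name': {...}, 'breed': {...}}}
#
# should come back as
#
#     {'allOf': [
#         {'$ref': '#/components/schemas/Pet'},
#         {'type': 'object', 'required': ['breed'],
#          'properties': {'breed': {...}, 'type': {...}}}
#     ], 'title': 'Dog'}
#
# where `name`, being declared and required on `Pet`, is left to the referenced
# parent schema.

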
def get_model_mapper(models, stoppage=None, full=True, include_enum=False):
    """Get a dictionary of name: class for all the objects in models."""
    no_enums = [model for model in models if not isinstance(model, enum.EnumMeta)]
    enums = [model for model in models if isinstance(model, enum.EnumMeta)]

    flat_models = []
    for model in no_enums:
        try:
            f_models = get_flat_models_from_model(get_model(model))
        except TypeError:
            warnings.warn(f'*** Invalid input object: {model}')
        else:
            flat_models.extend(f_models)

    flat_models = list(set(flat_models))
    if include_enum:
        flat_models.extend(enums)

    # this is the list of all the referenced objects
    model_name_map = get_model_name_map(flat_models)
    # flip the dictionary so each class can be accessed by name
    model_name_map = {v: k for k, v in model_name_map.items()}

    if full:
        stoppage = stoppage or set(
            ['NoExtraBaseModel', 'ModelMetaclass', 'BaseModel', 'object', 'str', 'Enum']
        )
        # Pydantic does not necessarily add all the baseclasses to the OpenAPI
        # documentation. We check all of them and add them to the list if they
        # are not already added.
        models = list(model_name_map.values())
        for model in models:
            for cls in type.mro(model):
                if cls.__name__ in stoppage:
                    break
                if cls.__name__ not in model_name_map:
                    model_name_map[cls.__name__] = cls

    # filter out enum objects
    if not include_enum:
        model_name_map = {
            k: v for k, v in model_name_map.items()
            if not isinstance(v, enum.EnumMeta)
        }

    # remove base type objects
    model_name_map = {
        k: v for k, v in model_name_map.items()
        if k not in ('str', 'int', 'dict')
    }

    assert len(model_name_map) > 0, 'Found no valid Pydantic model in input classes.'

    return model_name_map
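
# A short usage sketch with the hypothetical models above:
#
#     mapper = get_model_mapper([Dog], full=True, include_enum=False)
#     # -> {'Dog': Dog, 'Pet': Pet}: ancestors that Pydantic does not list
#     #    explicitly are walked up through the mro and added by name.

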
def class_mapper(models, find_and_replace=None):
    """Create a mapper between OpenAPI models and Python modules.

    This mapper is used by the dotnet generator to organize the models under a
    similar module structure.

    Args:
        models: Input Pydantic models.
        find_and_replace: A list of two strings: the pattern to find in module
            paths and the value to replace it with.
    """
    if not hasattr(models, '__iter__'):
        models = [models]

    mapper = get_model_mapper(models, full=True, include_enum=True)

    # collect enum classes from the updated schemas
    schemas = get_schemas_inheritance(models)
    enums = {}
    for name in schemas:
        s = schemas[name]
        if 'enum' in s:
            # add enum
            info = mapper[name]
            if info.__name__ not in enums:
                enums[info.__name__] = info

    module_mapper = {}
    # remove enums from the class mapper and keep the module path for each class
    classes = {k: c.__module__ for k, c in mapper.items() if k not in enums}
    enums = {k: c.__module__ for k, c in enums.items()}

    if find_and_replace:
        fi, rep = find_and_replace
        for k, v in classes.items():
            classes[k] = v.replace(fi, rep)
        for k, v in enums.items():
            enums[k] = v.replace(fi, rep)

    # this sorting relies on dictionaries keeping insertion order (Python 3.7+)
    module_mapper['classes'] = {k: classes[k] for k in sorted(classes)}
    module_mapper['enums'] = {k: enums[k] for k in sorted(enums)}

    return module_mapper
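
# Usage sketch; the module path and its replacement below are hypothetical:
#
#     mapper = class_mapper(
#         [Dog], find_and_replace=['my_package.model', 'MyPackage.Model'])
#     # -> {'classes': {'Dog': 'MyPackage.Model.pet', 'Pet': ...}, 'enums': {}}
#
# The find/replace pair rewrites the Python module paths into the namespace
# convention that the dotnet generator expects.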